Add validation rules for builder parameters.

Ville Pietilä
2025-11-03 09:57:27 +00:00
parent aeaa457e75
commit fd15355261
2 changed files with 1337 additions and 0 deletions


@@ -0,0 +1,512 @@
# Template Parameter Constraint Rules for Forward Convolution Device Operations
This document lists all static_assert rules and runtime validation checks that constrain template parameter selection for the five forward convolution device operations in Composable Kernel.
## Device Operations Analyzed
1. **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3**
2. **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle**
3. **DeviceGroupedConvFwdMultipleD_Wmma_CShuffle**
4. **DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor**
5. **DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK**
---
## 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
### Gridwise Implementation
Uses: `GridwiseGemmMultiD_xdl_cshuffle_v3`
### Compile-Time Static Asserts (Gridwise Level)
#### Block and Wave Tiling Constraints
```cpp
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
```
- **Rule**: MPerBlock must be divisible by (MPerXdl × MXdlPerWave)
- **Rule**: NPerBlock must be divisible by (NXdlPerWave × NPerXdl)
#### Shuffle Constraints
```cpp
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
```
- **Rule**: MXdlPerWave must be divisible by CShuffleMXdlPerWavePerShuffle
- **Rule**: NXdlPerWave must be divisible by CShuffleNXdlPerWavePerShuffle
### Compile-Time Static Asserts (Blockwise Level)
From `BlockwiseGemmXdlops`:
```cpp
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 &&
NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
```
- **Rule**: MPerBlock must be divisible by (MPerXDL × MRepeat)
- **Rule**: NPerBlock must be divisible by (NPerXDL × NRepeat)
- **Rule**: KPerThread must be divisible by KPack
- **Rule**: BlockSize must equal MWaves × NWaves × WaveSize
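As a worked instance of the blockwise rules above, the following sketch derives the wave counts from a tile configuration and checks the BlockSize identity. The struct `XdlTileConfig`, the helper `tiling_is_consistent`, the sample values, and the wave size of 64 are all assumptions made for illustration; none of them are part of the Composable Kernel API.

```cpp
#include <cstdint>

// Hypothetical bundle of tuning parameters; field names mirror the rules above.
struct XdlTileConfig
{
    std::int32_t block_size;
    std::int32_t m_per_block, n_per_block;
    std::int32_t m_per_xdl, n_per_xdl;
    std::int32_t m_repeat, n_repeat; // MRepeat / NRepeat (MXdlPerWave / NXdlPerWave)
    std::int32_t wave_size;          // assumed to be 64 in the examples below
};

// True when the block tile divides evenly into per-wave tiles and the
// resulting wave grid accounts for every thread in the block.
constexpr bool tiling_is_consistent(const XdlTileConfig& c)
{
    const std::int32_t m_wave_tile = c.m_per_xdl * c.m_repeat;
    const std::int32_t n_wave_tile = c.n_per_xdl * c.n_repeat;
    if(m_wave_tile <= 0 || n_wave_tile <= 0)
        return false;
    if(c.m_per_block % m_wave_tile != 0 || c.n_per_block % n_wave_tile != 0)
        return false;
    const std::int32_t m_waves = c.m_per_block / m_wave_tile;
    const std::int32_t n_waves = c.n_per_block / n_wave_tile;
    return c.block_size == m_waves * n_waves * c.wave_size;
}

// 128x128 tile, 32x32 XDL, 2x2 repeats: 2 x 2 waves x 64 lanes = 256 threads.
static_assert(tiling_is_consistent({256, 128, 128, 32, 32, 2, 2, 64}),
              "256 threads match a 2x2 wave grid");
// The same tile with 128 threads leaves half of the waves unaccounted for.
static_assert(!tiling_is_consistent({128, 128, 128, 32, 32, 2, 2, 64}),
              "128 threads cannot cover a 2x2 wave grid");
```

Keeping the divisibility test ahead of the division avoids silently truncated wave counts, which would otherwise let some invalid tiles pass the BlockSize comparison.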
### Runtime Validation Checks
#### Vector Access for A (Input) Tensor
For layouts G_NW_C, G_NHW_C, G_NDHW_C, GNWC, GNHWC, GNDHWC, NWGC, NHWGC, NDHWGC, NGCW, NGCHW, NGCDHW:
```cpp
C % ABlockTransferSrcScalarPerVector == 0
```
- **Rule**: C (input channels) must be divisible by ABlockTransferSrcScalarPerVector when ABlockTransferSrcVectorDim == 2
#### Vector Access for B (Weight) Tensor
For layouts G_K_X_C, G_K_YX_C, G_K_ZYX_C, GKXC, GKYXC, GKZYXC, KXGC, KYXGC, KZYXGC, GKCX, GKCYX, GKCZYX:
```cpp
C % BBlockTransferSrcScalarPerVector == 0
```
- **Rule**: C (input channels) must be divisible by BBlockTransferSrcScalarPerVector when BBlockTransferSrcVectorDim == 2
#### Vector Access for E (Output) Tensor
For layouts G_NW_K, G_NHW_K, G_NDHW_K, GNWK, GNHWK, GNDHWK, NWGK, NHWGK, NDHWGK, NGKW, NGKHW, NGKDHW:
```cpp
K % CDEBlockTransferScalarPerVector_NPerBlock == 0
```
- **Rule**: K (output channels) must be divisible by CDEBlockTransferScalarPerVector_NPerBlock
#### Special NGCHW/NGCDHW Layout Constraints
For NGCHW/NGCDHW layouts requiring transpose:
```cpp
(G * C) % CDEBlockTransferScalarPerVector_NPerBlock == 0
(G * K) % CDEBlockTransferScalarPerVector_NPerBlock == 0
input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0
output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0
```
- **Rule**: G×C must be divisible by CDEBlockTransferScalarPerVector_NPerBlock
- **Rule**: G×K must be divisible by CDEBlockTransferScalarPerVector_NPerBlock
- **Rule**: Product of input spatial dimensions must be divisible by CDEBlockTransferScalarPerVector_NPerBlock
- **Rule**: Product of output spatial dimensions must be divisible by CDEBlockTransferScalarPerVector_NPerBlock
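The four transpose-related rules above can be sketched as a single predicate. This is a minimal illustration that assumes the spatial lengths are available as arrays; the names `spatial_product` and `transpose_vector_access_ok` are hypothetical and do not exist in Composable Kernel.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Product of the spatial lengths, e.g. Hi * Wi for a 2D convolution.
template <std::size_t NDim>
constexpr std::int64_t spatial_product(const std::array<std::int64_t, NDim>& lengths)
{
    std::int64_t acc = 1;
    for(std::size_t i = 0; i < NDim; ++i)
        acc *= lengths[i];
    return acc;
}

// Hypothetical helper: applies the four divisibility rules listed above.
template <std::size_t NDim>
constexpr bool transpose_vector_access_ok(std::int64_t G,
                                          std::int64_t C,
                                          std::int64_t K,
                                          const std::array<std::int64_t, NDim>& in_spatial,
                                          const std::array<std::int64_t, NDim>& out_spatial,
                                          std::int64_t scalar_per_vector)
{
    if(scalar_per_vector <= 0)
        return false;
    return (G * C) % scalar_per_vector == 0 &&
           (G * K) % scalar_per_vector == 0 &&
           spatial_product(in_spatial) % scalar_per_vector == 0 &&
           spatial_product(out_spatial) % scalar_per_vector == 0;
}

using Dims2 = std::array<std::int64_t, 2>;

// 56x56 feature maps: 3136 elements, divisible by a vector width of 8.
static_assert(transpose_vector_access_ok(1, 64, 128, Dims2{{56, 56}}, Dims2{{56, 56}}, 8),
              "56x56 maps vectorize by 8");
// 7x7 feature maps: 49 elements, not divisible by 8.
static_assert(!transpose_vector_access_ok(1, 64, 128, Dims2{{7, 7}}, Dims2{{7, 7}}, 8),
              "7x7 maps do not vectorize by 8");
```

Small late-stage feature maps such as 7×7 are the typical case where this check fails, so a narrower vector width is needed for them.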
#### Descriptor Size Constraints
```cpp
a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= 2GB
b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= 2GB
c_grid_desc.GetElementSpaceSize() * sizeof(CDataType) <= 2GB
```
- **Rule**: Each tensor descriptor must represent no more than 2GB of data
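The size rule can be checked without overflow by dividing the limit instead of multiplying the element count. The sketch below assumes that "2GB" means 2^31 bytes and that the element space size is available as a 64-bit integer; `kTwoGiB` and `fits_in_2gb` are hypothetical names used only for this illustration.

```cpp
#include <cstdint>

// Assumption: the 2GB limit is 2^31 bytes.
constexpr std::int64_t kTwoGiB = std::int64_t{1} << 31;

// True when element_space_size * sizeof(DataType) <= 2^31 bytes.
// Comparing against limit / sizeof keeps the arithmetic free of overflow.
template <typename DataType>
constexpr bool fits_in_2gb(std::int64_t element_space_size)
{
    return element_space_size >= 0 &&
           element_space_size <= kTwoGiB / static_cast<std::int64_t>(sizeof(DataType));
}

// 2^29 floats occupy exactly 2^31 bytes, which is still allowed.
static_assert(fits_in_2gb<float>(std::int64_t{1} << 29), "exactly at the limit");
// One more element pushes the descriptor over the limit.
static_assert(!fits_in_2gb<float>((std::int64_t{1} << 29) + 1), "one element too many");
```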
#### Device-Specific Constraints
- On **gfx908**: AccDataType must be `float` or `int32_t`
- **DirectLoad** mode: Only supported on gfx950
---
## 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
### Gridwise Implementation
Uses:
- `GridwiseGemmMultipleABD_xdl_cshuffle` (when isMultiA || isMultiB)
- `GridwiseGemmMultipleD_xdl_cshuffle` (otherwise)
### Compile-Time Static Asserts (Gridwise Level)
Same tiling and shuffle asserts as the V3 version, together with the KPerBlock divisibility assert:
```cpp
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
```
- **Rule**: MPerBlock must be divisible by (MPerXdl × MXdlPerWave)
- **Rule**: NPerBlock must be divisible by (NXdlPerWave × NPerXdl)
- **Rule**: KPerBlock must be divisible by both AK1Value and BK1Value
- **Rule**: MXdlPerWave must be divisible by CShuffleMXdlPerWavePerShuffle
- **Rule**: NXdlPerWave must be divisible by CShuffleNXdlPerWavePerShuffle
### Runtime Validation Checks
#### Vector Access for A (Input) Tensor
For standard layouts (G_NW_C, G_NHW_C, etc.):
```cpp
C % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim == 2
```
- **Rule**: C must be divisible by ABlockTransferSrcScalarPerVector
Alternative for grouped layouts with C==1 or NumGroupsToMerge==1:
```cpp
G % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim == 1
```
- **Rule**: G must be divisible by ABlockTransferSrcScalarPerVector when vectorizing along the G dimension
For NGCHW/NGCDHW layouts without transpose:
```cpp
input_spatial_acum % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim == 1
```
- **Rule**: Product of input spatial dimensions must be divisible by ABlockTransferSrcScalarPerVector
#### Vector Access for B (Weight) Tensor
```cpp
C % BBlockTransferSrcScalarPerVector == 0 // When BBlockTransferSrcVectorDim == 2
```
- **Rule**: C must be divisible by BBlockTransferSrcScalarPerVector
#### Vector Access for D Tensors
For each D tensor with layouts G_NW_K, G_NHW_K, etc.:
```cpp
K % CDEBlockTransferScalarPerVector_NPerBlock == 0
```
- **Rule**: K must be divisible by CDEBlockTransferScalarPerVector_NPerBlock
- **Rule**: D and E tensors must have identical shapes (all dimensions must match)
#### Vector Access for E (Output) Tensor
For standard layouts:
```cpp
K % CDEBlockTransferScalarPerVector_NPerBlock == 0 // When CTranspose == false
```
For transposed layouts:
```cpp
output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0
```
#### Transpose Kernel Requirements
For NGCHW/NGCDHW layouts with transpose:
```cpp
(G * C) % CDEBlockTransferScalarPerVector_NPerBlock == 0
(G * K) % CDEBlockTransferScalarPerVector_NPerBlock == 0
input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0
output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0
```
- Workspace pointer must be allocated
#### NumGroupsToMerge Constraints
When NumGroupsToMerge > 1:
```cpp
C == 1
G % NumGroupsToMerge == 0
```
- **Rule**: C must equal 1
- **Rule**: G must be divisible by NumGroupsToMerge
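A minimal sketch of the group-merging rule, assuming G, C, and NumGroupsToMerge are known as plain integers. The function name `group_merge_ok` is hypothetical.

```cpp
#include <cstdint>

// Hypothetical helper: the constraints only apply when more than one
// group is merged; NumGroupsToMerge == 1 means no merging takes place.
constexpr bool group_merge_ok(std::int64_t G, std::int64_t C, std::int64_t num_groups_to_merge)
{
    if(num_groups_to_merge < 1)
        return false;
    if(num_groups_to_merge == 1)
        return true;
    return C == 1 && G % num_groups_to_merge == 0;
}

static_assert(group_merge_ok(32, 1, 4), "depthwise-style problem, 32 groups merged by 4");
static_assert(!group_merge_ok(32, 4, 4), "merging requires C == 1");
static_assert(!group_merge_ok(30, 1, 4), "G must be a multiple of NumGroupsToMerge");
```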
#### Tensor Size Constraints
```cpp
a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= 2GB
b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= 2GB
e_grid_desc.GetElementSpaceSize() * sizeof(EDataType) <= 2GB
```
---
## 3. DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
### Gridwise Implementation
Uses: `GridwiseGemmMultipleD_Wmma`
### Compile-Time Static Asserts
```cpp
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerWmma)) == 0,
"Invalid tuning param!");
static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize");
```
- **Rule**: MPerBlock must be divisible by (MPerWmma × MRepeat)
- **Rule**: NPerBlock must be divisible by (NRepeat × NPerWmma)
- **Rule**: KPack must be divisible by (A_K1 × A_KRow) where A_KRow = 2
- **Rule**: KPack must be divisible by (B_K1 × B_KRow) where B_KRow = 2
- **Rule**: BlockSize must equal MWaves × NWaves × WaveSize
- Where: MWaves = MPerBlock / (MRepeat × MPerWmma)
- Where: NWaves = NPerBlock / (NRepeat × NPerWmma)
### Derived Constraints
```cpp
K % K1 == 0 // Asserted in MakeAGridDescriptor and MakeBGridDescriptor
KPack = math::integer_least_multiple(K1, WmmaK) // Where WmmaK = 16 (32 when K1 == 16)
```
- **Rule**: K must be divisible by K1
- **Rule**: KPack is K1 rounded up to the next multiple of WmmaK, so KPack is always a multiple of WmmaK and at least K1
### Runtime Validation Checks
#### Device Support
```cpp
ck::is_gfx11_supported() || ck::is_gfx12_supported()
```
- **Rule**: Only supports gfx11 and gfx12 architectures
- **Rule**: On these devices, AccDataType must be `float` or `int32_t`
#### Vector Access for A
For layouts G_NW_C, G_NHW_C, G_NDHW_C, GNWC, GNHWC, GNDHWC, NWGC, NHWGC, NDHWGC:
```cpp
C % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim == 2
```
#### Vector Access for B
For layouts G_K_X_C, G_K_YX_C, G_K_ZYX_C, GKXC, GKYXC, GKZYXC, KXGC, KYXGC, KZYXGC:
```cpp
C % BBlockTransferSrcScalarPerVector == 0 // When BBlockTransferSrcVectorDim == 2
```
#### Vector Access for D and E
For all D tensors and E tensor:
```cpp
K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0
```
---
## 4. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
### Gridwise Implementation
Uses: `GridwiseGemmMultipleD_xdl_cshuffle`
### Compile-Time Static Asserts
Same as other XDL-based operations:
```cpp
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
```
### Runtime Validation Checks
#### Tensor Splitting Validation
This operation splits large tensors that exceed 2GB:
```cpp
is_split_valid_ && gemms_count_ == valid_gemms_count_
```
- **Rule**: The tensor splitting algorithm must successfully partition the problem into sub-problems < 2GB
#### D and E Tensor Matching
```cpp
ds_g_n_k_wos_strides_[i] == e_g_n_k_wos_strides_
ds_g_n_k_wos_lengths_[i] == e_g_n_k_wos_lengths_
```
- **Rule**: All D tensors must have identical strides and lengths to E tensor
#### Vector Access Constraints
Same as standard XDL operations:
```cpp
C % ABlockTransferSrcScalarPerVector == 0
C % BBlockTransferSrcScalarPerVector == 0
K % CDEBlockTransferScalarPerVector_NPerBlock == 0
```
---
## 5. DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
### Gridwise Implementation
Uses: `GridwiseGemmDlMultipleD_km_kn_mn`
### Compile-Time Static Asserts
From `BlockwiseGemmDl_v2r3`:
```cpp
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(BM0 == 2 && BN0 == 2, "wrong");
static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
"wrong! blocksize and cluster size not consistent");
```
- **Rule**: BM (MPerBlock) must be divisible by BM1
- **Rule**: BN (NPerBlock) must be divisible by BN1
- **Rule**: BM0 must equal 2
- **Rule**: BN0 must equal 2
- **Rule**: BlockSize must equal the product of thread cluster dimensions
### Runtime Validation Checks
#### Device Support
```cpp
ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()
```
- **Rule**: Must be one of: gfx906, gfx103, gfx11, gfx12, or support XDL instructions
#### Vector Transfer Constraints for A
```cpp
srcVectorLengths[I1] == 1 && srcVectorLengths[I2] == 1
K1 % srcVectorLengths[I3] == 0
K0PerBlock % srcVectorLengths[I0] == 0
C % (srcVectorLengths[I0] * srcVectorLengths[I3]) == 0
```
- **Rule**: Vector lengths along the M0 and M1 dimensions (indices I1 and I2) must be 1
- **Rule**: K1 must be divisible by the vector length along the K1 dimension (index I3)
- **Rule**: K0PerBlock must be divisible by the vector length along the K0 dimension (index I0)
- **Rule**: C must be divisible by the product of the K0 and K1 vector lengths
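The four rules above can be combined into one predicate. The sketch assumes the vector lengths are stored in the order K0, M0, M1, K1; the alias `VecLens` and the function `dl_src_vector_ok` are hypothetical names for illustration.

```cpp
#include <array>
#include <cstdint>

// Assumed index order: [0] = K0, [1] = M0, [2] = M1, [3] = K1.
using VecLens = std::array<std::int64_t, 4>;

// Hypothetical helper applying the DL source-vector rules for one tensor.
constexpr bool dl_src_vector_ok(const VecLens& v,
                                std::int64_t K0PerBlock,
                                std::int64_t K1,
                                std::int64_t C)
{
    if(v[0] <= 0 || v[3] <= 0)
        return false;
    return v[1] == 1 && v[2] == 1 &&   // no vectorization along M0 / M1
           K1 % v[3] == 0 &&           // K1 vector length divides K1
           K0PerBlock % v[0] == 0 &&   // K0 vector length divides K0PerBlock
           C % (v[0] * v[3]) == 0;     // channels cover whole vectors
}

// Vectorize only along K1 with width 2.
static_assert(dl_src_vector_ok(VecLens{{1, 1, 1, 2}}, 16, 2, 64), "valid K1 vectorization");
// A vector length of 2 along M0 violates the first rule.
static_assert(!dl_src_vector_ok(VecLens{{1, 2, 1, 2}}, 16, 2, 64), "M0 must not be vectorized");
```

The same predicate applies to the B tensor, with the middle two entries read as N0 and N1.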
#### Vector Transfer Constraints for B
Same structure as A:
```cpp
srcVectorLengths[I1] == 1 && srcVectorLengths[I2] == 1
K1 % srcVectorLengths[I3] == 0
K0PerBlock % srcVectorLengths[I0] == 0
C % (srcVectorLengths[I0] * srcVectorLengths[I3]) == 0
```
#### Vector Access for E (Output)
```cpp
K % CThreadTransferDstScalarPerVector == 0
CThreadTransferSrcDstVectorDim == 5
```
- **Rule**: K must be divisible by CThreadTransferDstScalarPerVector
- **Rule**: Vector dimension must be 5 (the K dimension)
#### Tile Size Constraints
```cpp
M % MPerBlock == 0
N % NPerBlock == 0
K0 % K0PerBlock == 0
```
- **Rule**: M must be divisible by MPerBlock
- **Rule**: N must be divisible by NPerBlock
- **Rule**: K0 must be divisible by K0PerBlock
---
## Common Rules Across All Operations
### Specialization Requirements
For **Filter1x1Stride1Pad0** specialization:
```cpp
FilterSpatialDim == 1
ConvStride == 1
LeftPad == 0
RightPad == 0
```
- Must be true for all spatial dimensions
For **Filter1x1Pad0** specialization:
```cpp
FilterSpatialDim == 1
LeftPad == 0
RightPad == 0
```
- Must be true for all spatial dimensions
For **Filter3x3** specialization:
```cpp
C == 1
FilterSpatialDim == 3
```
- Must be true for all spatial dimensions
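To illustrate the "for all spatial dimensions" wording, the sketch below checks the two 1x1 specializations over every spatial dimension. The struct `ConvProblem` and both function names are hypothetical; they only stand in for whatever problem description the caller has.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical problem description holding one entry per spatial dimension.
template <std::size_t NDim>
struct ConvProblem
{
    std::array<std::int64_t, NDim> filter;
    std::array<std::int64_t, NDim> stride;
    std::array<std::int64_t, NDim> left_pad;
    std::array<std::int64_t, NDim> right_pad;
};

// Filter1x1Pad0: unit filter and zero padding in every spatial dimension.
template <std::size_t NDim>
constexpr bool is_filter1x1_pad0(const ConvProblem<NDim>& p)
{
    for(std::size_t i = 0; i < NDim; ++i)
        if(p.filter[i] != 1 || p.left_pad[i] != 0 || p.right_pad[i] != 0)
            return false;
    return true;
}

// Filter1x1Stride1Pad0: the above plus unit stride in every spatial dimension.
template <std::size_t NDim>
constexpr bool is_filter1x1_stride1_pad0(const ConvProblem<NDim>& p)
{
    if(!is_filter1x1_pad0(p))
        return false;
    for(std::size_t i = 0; i < NDim; ++i)
        if(p.stride[i] != 1)
            return false;
    return true;
}

using Dims2 = std::array<std::int64_t, 2>;
constexpr ConvProblem<2> pointwise{Dims2{{1, 1}}, Dims2{{1, 1}}, Dims2{{0, 0}}, Dims2{{0, 0}}};
constexpr ConvProblem<2> strided{Dims2{{1, 1}}, Dims2{{2, 2}}, Dims2{{0, 0}}, Dims2{{0, 0}}};

static_assert(is_filter1x1_stride1_pad0(pointwise), "pointwise conv meets the strictest case");
static_assert(is_filter1x1_pad0(strided) && !is_filter1x1_stride1_pad0(strided),
              "stride 2 only qualifies for Filter1x1Pad0");
```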
### Pipeline Stage Constraints
For non-v1 pipeline versions:
```cpp
num_k_loop > PrefetchStages
```
- **Rule**: Number of K-blocks must exceed the number of prefetch stages
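A small sketch of the pipeline rule follows. It assumes that the number of K loops is the GEMM K extent divided by KPerBlock, rounded up; that derivation, and the names `num_k_loop` and `pipeline_k_loops_ok`, are assumptions for illustration rather than the library's own code.

```cpp
#include <cstdint>

// Assumption: one K loop per KPerBlock-sized slice of the GEMM K extent.
constexpr std::int64_t num_k_loop(std::int64_t gemm_k, std::int64_t k_per_block)
{
    return (gemm_k + k_per_block - 1) / k_per_block;
}

// The rule above only applies to non-v1 pipeline versions.
constexpr bool pipeline_k_loops_ok(std::int64_t gemm_k,
                                   std::int64_t k_per_block,
                                   std::int64_t prefetch_stages,
                                   bool is_v1_pipeline)
{
    if(k_per_block <= 0)
        return false;
    if(is_v1_pipeline)
        return true;
    return num_k_loop(gemm_k, k_per_block) > prefetch_stages;
}

// K = 256 with KPerBlock = 64 gives 4 loops, enough for 2 prefetch stages.
static_assert(pipeline_k_loops_ok(256, 64, 2, false), "4 loops > 2 stages");
// K = 64 gives a single loop, too few for a 2-stage prefetch.
static_assert(!pipeline_k_loops_ok(64, 64, 2, false), "1 loop is not > 2 stages");
```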
### TF32 Support Constraints
```cpp
is_same_v<AComputeDataType, BComputeDataType> // When using TF32
is_tf32_supported() // Device must support TF32
```
- **Rule**: When using TF32, A and B compute data types must match
- **Rule**: Device must have TF32 support
### XDL/WMMA Support Validation
```cpp
ck::is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXdl, NPerXdl>()
```
- **Rule**: The combination of data types and XDL/WMMA tile sizes must be supported by the device
---
## Summary of Key Parameter Relationships
### Block-Level Tiling
```
MPerBlock = MPerXdl × MXdlPerWave × MWaves
NPerBlock = NPerXdl × NXdlPerWave × NWaves
BlockSize = MWaves × NWaves × WaveSize
```
For WMMA:
```
MPerBlock = MPerWmma × MRepeat × MWaves
NPerBlock = NPerWmma × NRepeat × NWaves
```
### K-Dimension Decomposition
```
K = AK0 × AK1 = BK0 × BK1
KPerBlock = AK0PerBlock × AK1 = BK0PerBlock × BK1
```
### Shuffle Constraints
```
MXdlPerWave = i × CShuffleMXdlPerWavePerShuffle (i is a positive integer)
NXdlPerWave = j × CShuffleNXdlPerWavePerShuffle (j is a positive integer)
```
### Vector Access Hierarchy
1. **Data must be aligned** to vector access size
2. **Dimensions accessed vectorially** must be divisible by ScalarPerVector
3. **Different layouts** have different vectorizable dimensions:
   - Row-major A: vectorize K dimension
   - Column-major A: vectorize M dimension
   - Row-major B: vectorize N dimension
   - Column-major B: vectorize K dimension
### LDS Padding
```
ABlockLdsExtraM: Padding for A matrix in LDS to avoid bank conflicts
BBlockLdsExtraN: Padding for B matrix in LDS to avoid bank conflicts
```
- Often set to 1 on gfx950 to reduce bank conflicts
---
## Implementation Notes
1. **Hierarchy**: Device ops compose gridwise ops, which compose blockwise ops, which compose threadwise ops
2. **Memory Flow**: Global → LDS → Register (VGPR) → Compute → Register → LDS → Global
3. **Direct Load**: Some implementations support direct global-to-register load (bypassing LDS) on gfx950
4. **Pipeline Versions**: Different pipeline versions (v1, v2, v3, v4, v5) have different prefetch and scheduling strategies
5. **Multi-AB Support**: Some operations support multiple A/B input tensors (tuples)
6. **Transpose Support**: Some layouts require intermediate transpose operations with workspace allocation
---
## Validation Checklist for Template Parameter Selection
When selecting template parameters, verify:
- [ ] MPerBlock % (MPerXdl × MXdlPerWave) == 0
- [ ] NPerBlock % (NPerXdl × NXdlPerWave) == 0
- [ ] KPerBlock % AK1 == 0 and KPerBlock % BK1 == 0
- [ ] BlockSize == MWaves × NWaves × WaveSize
- [ ] All channel/spatial dimensions divisible by respective ScalarPerVector values
- [ ] Tensor descriptors no larger than 2GB each
- [ ] Correct device architecture and data type support
- [ ] Specialization requirements met (filter size, stride, padding)
- [ ] Shuffle parameters properly divide wave parameters
- [ ] Pipeline stage requirements met for chosen version
- [ ] Workspace allocated if using transpose kernels


@@ -0,0 +1,825 @@
# Template Parameter Validation Guide for Forward Convolution Device Operations
This guide maps each template parameter to its constraint rules, enabling upstream validation before instantiation.
## Quick Reference: Template Parameter Names
### Common XDL-based Parameters
- **BlockSize**: Number of threads per block
- **MPerBlock, NPerBlock, KPerBlock**: Block tile sizes for GEMM dimensions
- **AK1, BK1**: K-dimension decomposition (K = K0 × K1)
- **MPerXDL, NPerXDL**: XDL/MFMA instruction tile size
- **MXdlPerWave, NXdlPerWave**: Number of XDL tiles per wave
- **ABlockTransferSrcScalarPerVector**: Vector width for loading A matrix
- **BBlockTransferSrcScalarPerVector**: Vector width for loading B matrix
- **ABlockTransferDstScalarPerVector_AK1**: Vector width for storing A to LDS
- **BBlockTransferDstScalarPerVector_BK1**: Vector width for storing B to LDS
- **CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle**: C-shuffle granularity
- **CDEBlockTransferScalarPerVector_NPerBlock**: Vector width for C/D/E transfers
- **ABlockLdsExtraM, BBlockLdsExtraN**: LDS padding to avoid bank conflicts
### WMMA-specific Parameters
- **MPerWmma, NPerWmma**: WMMA instruction tile size (typically 16×16)
- **K1**: K-dimension granularity
- **MRepeat, NRepeat**: Number of WMMA tiles per wave
- **CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle**: C-shuffle granularity
### DL-specific Parameters
- **K0PerBlock**: K0 dimension of block tile
- **K1**: K1 value (typically 4 or 8)
- **M1PerThread, N1PerThread**: Thread tile size
- **KPerThread**: K dimension per thread
---
## 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
### Template Declaration
```cpp
template <
index_t NDimSpatial,
typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
typename ADataType, typename BDataType, typename AccDataType,
typename CShuffleDataType, typename DsDataType, typename EDataType,
typename AElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
index_t AK1, index_t BK1,
index_t MPerXDL, index_t NPerXDL,
index_t MXdlPerWave, index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename AComputeDataType = ...,
typename BComputeDataType = AComputeDataType,
bool DirectLoad = false
>
```
### Compile-Time Constraints (Can Check Before Instantiation)
#### Rule 1: Block Tiling Divisibility
```cpp
CONSTRAINT: MPerBlock % (MPerXDL * MXdlPerWave) == 0
PARAMETERS: MPerBlock, MPerXDL, MXdlPerWave
```
#### Rule 2: Block Tiling Divisibility (N dimension)
```cpp
CONSTRAINT: NPerBlock % (NXdlPerWave * NPerXDL) == 0
PARAMETERS: NPerBlock, NPerXDL, NXdlPerWave
```
#### Rule 3: K-dimension Decomposition
```cpp
CONSTRAINT: KPerBlock % AK1 == 0
CONSTRAINT: KPerBlock % BK1 == 0
PARAMETERS: KPerBlock, AK1, BK1
```
#### Rule 4: Shuffle Granularity (M dimension)
```cpp
CONSTRAINT: MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0
PARAMETERS: MXdlPerWave, CShuffleMXdlPerWavePerShuffle
```
#### Rule 5: Shuffle Granularity (N dimension)
```cpp
CONSTRAINT: NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0
PARAMETERS: NXdlPerWave, CShuffleNXdlPerWavePerShuffle
```
#### Rule 6: Derived - Wave Count and BlockSize
```cpp
DERIVED: MWaves = MPerBlock / (MPerXDL * MXdlPerWave)
DERIVED: NWaves = NPerBlock / (NPerXDL * NXdlPerWave)
DERIVED: WaveSize = 64 (or 32 for some architectures)
CONSTRAINT: BlockSize == MWaves * NWaves * WaveSize
PARAMETERS: BlockSize, MPerBlock, NPerBlock, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
```
#### Rule 7: Derived - KPerThread and KPack
```cpp
DERIVED: KPack = max(lcm(AK1, BK1), MFMA_k_per_blk)
DERIVED: KPerThread = KPerBlock / (KPack)
CONSTRAINT: KPerThread % KPack == 0
PARAMETERS: KPerBlock, AK1, BK1, MPerXDL, NPerXDL (MFMA size affects KPack)
```
#### Rule 8: XDL Support
```cpp
CONSTRAINT: is_xdl_wmma_supported<AComputeDataType, BComputeDataType, MPerXDL, NPerXDL>()
PARAMETERS: AComputeDataType, BComputeDataType, MPerXDL, NPerXDL
```
#### Rule 9: TF32 Constraints
```cpp
CONSTRAINT: IF (AComputeDataType == tf32 OR BComputeDataType == tf32)
THEN AComputeDataType == BComputeDataType
PARAMETERS: AComputeDataType, BComputeDataType
```
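Rule 9 can be expressed as a compile-time type check. In the sketch below, `tf32_tag` is a hypothetical stand-in for the library's TF32 compute type, and `tf32_types_consistent` is an illustrative name, not an existing function.

```cpp
#include <type_traits>

// Hypothetical stand-in for the TF32 compute type.
struct tf32_tag
{
};

// True unless exactly one of the two compute types is TF32.
template <typename AComputeDataType, typename BComputeDataType>
constexpr bool tf32_types_consistent()
{
    constexpr bool uses_tf32 = std::is_same<AComputeDataType, tf32_tag>::value ||
                               std::is_same<BComputeDataType, tf32_tag>::value;
    return !uses_tf32 || std::is_same<AComputeDataType, BComputeDataType>::value;
}

static_assert(tf32_types_consistent<tf32_tag, tf32_tag>(), "both TF32 is allowed");
static_assert(!tf32_types_consistent<tf32_tag, float>(), "mixed TF32 and float is rejected");
static_assert(tf32_types_consistent<float, float>(), "rule does not apply without TF32");
```

The device-side half of the TF32 requirement (hardware support) still has to be checked at runtime, since it depends on the target architecture.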
#### Rule 10: DirectLoad Limitation
```cpp
CONSTRAINT: IF DirectLoad == true THEN device == "gfx950"
CONSTRAINT: IF DirectLoad == true THEN
AElementwiseOperation == PassThrough AND
BElementwiseOperation == PassThrough
PARAMETERS: DirectLoad, AElementwiseOperation, BElementwiseOperation
```
### Upstream Validation Function Template
```cpp
template<typename DeviceOp>
struct TemplateParameterValidator {
    static constexpr bool IsValid() {
        // Rule 1: MPerBlock divisibility
        if constexpr(DeviceOp::MPerBlock %
                     (DeviceOp::MPerXDL * DeviceOp::MXdlPerWave) != 0)
            return false;
        // Rule 2: NPerBlock divisibility
        if constexpr(DeviceOp::NPerBlock %
                     (DeviceOp::NXdlPerWave * DeviceOp::NPerXDL) != 0)
            return false;
        // Rule 3: KPerBlock divisibility
        if constexpr(DeviceOp::KPerBlock % DeviceOp::AK1 != 0 ||
                     DeviceOp::KPerBlock % DeviceOp::BK1 != 0)
            return false;
        // Rule 4-5: Shuffle constraints
        if constexpr(DeviceOp::MXdlPerWave % DeviceOp::CShuffleMXdlPerWavePerShuffle != 0)
            return false;
        if constexpr(DeviceOp::NXdlPerWave % DeviceOp::CShuffleNXdlPerWavePerShuffle != 0)
            return false;
        // Rule 6: BlockSize validation
        constexpr auto MWaves = DeviceOp::MPerBlock /
                                (DeviceOp::MPerXDL * DeviceOp::MXdlPerWave);
        constexpr auto NWaves = DeviceOp::NPerBlock /
                                (DeviceOp::NPerXDL * DeviceOp::NXdlPerWave);
        constexpr auto WaveSize = 64; // Adjust based on architecture
        if constexpr(DeviceOp::BlockSize != MWaves * NWaves * WaveSize)
            return false;
        // Additional checks...
        return true;
    }
};
---
## 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
### Template Declaration
```cpp
template <
index_t NDimSpatial,
typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
typename ADataType, typename BDataType, typename AccDataType,
typename CShuffleDataType, typename DsDataType, typename EDataType,
typename AElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
index_t AK1, index_t BK1,
index_t MPerXDL, index_t NPerXDL,
index_t MXdlPerWave, index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType = ...,
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler(),
index_t NumGroupsToMerge = 1
>
```
### Compile-Time Constraints
Same as V3 (Rules 1-9 above) plus:
#### Rule 11: NumGroupsToMerge (compile-time checkable when C=1 is known)
```cpp
CONSTRAINT: IF NumGroupsToMerge > 1 THEN must use specific layouts
PARAMETERS: NumGroupsToMerge, ALayout, BLayout, ELayout
REQUIRED_LAYOUTS: NSpatialGC_GKSpatial_NSpatialGK OR
NGCSpatial_GKSpatial_NGKSpatial OR
NGCHW_NGKHW OR NGCDHW_NGKDHW
```
---
## 3. DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
### Template Declaration
```cpp
template <
index_t NDimSpatial,
typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
typename ADataType, typename BDataType, typename AccDataType,
typename CShuffleDataType, typename DsDataType, typename EDataType,
typename AElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
index_t K1,
index_t MPerWmma, index_t NPerWmma,
index_t MRepeat, index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1
>
```
### Compile-Time Constraints
#### Rule 12: WMMA Block Tiling (M dimension)
```cpp
CONSTRAINT: MPerBlock % (MPerWmma * MRepeat) == 0
PARAMETERS: MPerBlock, MPerWmma, MRepeat
```
#### Rule 13: WMMA Block Tiling (N dimension)
```cpp
CONSTRAINT: NPerBlock % (NPerWmma * NRepeat) == 0
PARAMETERS: NPerBlock, NPerWmma, NRepeat
```
#### Rule 14: WMMA Wave Count and BlockSize
```cpp
DERIVED: MWaves = MPerBlock / (MPerWmma * MRepeat)
DERIVED: NWaves = NPerBlock / (NPerWmma * NRepeat)
DERIVED: WaveSize = 32 (for WMMA architectures)
CONSTRAINT: BlockSize == MWaves * NWaves * WaveSize
PARAMETERS: BlockSize, MPerBlock, NPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat
```
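Rule 14 can also be read in the other direction: given the tile shape, it fixes the only admissible BlockSize. The sketch below computes that value under the wave size of 32 stated above; `wmma_block_size` is a hypothetical helper and the sample tile shapes are illustrative.

```cpp
#include <cstdint>

// Returns the BlockSize implied by the tile shape, or -1 when the block tile
// is not an exact multiple of the per-wave tile (Rules 12 and 13 violated).
constexpr std::int32_t wmma_block_size(std::int32_t m_per_block,
                                       std::int32_t n_per_block,
                                       std::int32_t m_per_wmma,
                                       std::int32_t n_per_wmma,
                                       std::int32_t m_repeat,
                                       std::int32_t n_repeat,
                                       std::int32_t wave_size = 32)
{
    const std::int32_t m_wave_tile = m_per_wmma * m_repeat;
    const std::int32_t n_wave_tile = n_per_wmma * n_repeat;
    if(m_wave_tile <= 0 || n_wave_tile <= 0)
        return -1;
    if(m_per_block % m_wave_tile != 0 || n_per_block % n_wave_tile != 0)
        return -1;
    return (m_per_block / m_wave_tile) * (n_per_block / n_wave_tile) * wave_size;
}

// 128x64 tile, 16x16 WMMA, 2x2 repeats: 4 x 2 waves x 32 lanes = 256 threads.
static_assert(wmma_block_size(128, 64, 16, 16, 2, 2) == 256, "4 x 2 waves of 32 lanes");
```

A generator of candidate instances could use such a function to derive BlockSize instead of enumerating it as a free parameter.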
#### Rule 15: WMMA KPack Constraints
```cpp
DERIVED: WmmaK = (K1 == 16) ? 32 : 16
DERIVED: KPack = math::integer_least_multiple(K1, WmmaK)
CONSTRAINT: KPack % (K1 * 2) == 0 // A_KRow = 2
CONSTRAINT: KPack % (K1 * 2) == 0 // B_KRow = 2
PARAMETERS: K1, KPerBlock
NOTE: KPerBlock should be chosen such that resulting KPack satisfies these
```
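The derivation in Rule 15 can be traced with a few lines of arithmetic. The sketch assumes that `integer_least_multiple(x, y)` rounds `x` up to the next multiple of `y`; the local function of that name is a re-implementation under that assumption, and `wmma_k`, `k_pack`, and `k_pack_ok` are hypothetical helper names.

```cpp
#include <cstdint>

// Assumed semantics: smallest multiple of y that is >= x.
constexpr std::int32_t integer_least_multiple(std::int32_t x, std::int32_t y)
{
    return y * ((x + y - 1) / y);
}

// WmmaK as given in Rule 15.
constexpr std::int32_t wmma_k(std::int32_t K1) { return K1 == 16 ? 32 : 16; }

constexpr std::int32_t k_pack(std::int32_t K1)
{
    return integer_least_multiple(K1, wmma_k(K1));
}

// KPack % (K1 * KRow) == 0 with KRow = 2 for both A and B.
constexpr bool k_pack_ok(std::int32_t K1)
{
    return K1 > 0 && k_pack(K1) % (K1 * 2) == 0;
}

static_assert(k_pack(8) == 16 && k_pack_ok(8), "K1 = 8: KPack = 16, divisible by 16");
static_assert(k_pack(16) == 32 && k_pack_ok(16), "K1 = 16: WmmaK = 32 keeps KPack divisible by 32");
static_assert(!k_pack_ok(32), "K1 = 32: KPack = 32 is not divisible by 64");
```

Under these assumptions, the switch to WmmaK = 32 is exactly what keeps K1 = 16 valid: with WmmaK = 16, KPack would be 16 and fail the divisibility by K1 × 2 = 32.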
#### Rule 16: Device Architecture for WMMA
```cpp
CONSTRAINT: is_gfx11_supported() OR is_gfx12_supported()
CONSTRAINT: AccDataType == float OR AccDataType == int32_t
PARAMETERS: AccDataType
DEVICE: Must be gfx11 or gfx12
```
---
## 4. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
### Template Declaration
Same as standard XDL version (similar to #2 but without NumGroupsToMerge)
### Compile-Time Constraints
Rules 1-9 apply (same as DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); Rule 10 does not, because this operation has no DirectLoad parameter
### Special Consideration
This operation handles tensors larger than 2GB by splitting them, so the descriptor size constraints are managed internally by the splitting algorithm.
---
## 5. DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
### Template Declaration
```cpp
template <
index_t NDimSpatial,
typename ADataType, typename BDataType,
typename DsDataType, typename EDataType, typename AccDataType,
typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
typename AElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock, index_t NPerBlock,
index_t K0PerBlock, index_t K1,
index_t M1PerThread, index_t N1PerThread, index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector
>
```
### Compile-Time Constraints
#### Rule 17: DL Thread Cluster Constraints
```cpp
DERIVED: BM1 = M1PerThread
DERIVED: BN1 = N1PerThread
DERIVED: BM = MPerBlock
DERIVED: BN = NPerBlock
CONSTRAINT: BM % BM1 == 0
CONSTRAINT: BN % BN1 == 0
PARAMETERS: MPerBlock, NPerBlock, M1PerThread, N1PerThread
```
#### Rule 18: DL Grid Decomposition
```cpp
CONSTRAINT: BM0 == 2
CONSTRAINT: BN0 == 2
DERIVED: BM0, BN0 are derived from thread cluster configuration
PARAMETERS: M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs
NOTE: Thread cluster must result in BM0=2, BN0=2
```
#### Rule 19: DL BlockSize Validation
```cpp
DERIVED: BM101, BM100, BN101, BN100 from thread cluster
CONSTRAINT: BlockSize == BM101 * BM100 * BN101 * BN100
PARAMETERS: BlockSize, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs
```
#### Rule 20: DL Vector Transfer Dimensions
```cpp
CONSTRAINT: ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1[I1] == 1
CONSTRAINT: ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1[I2] == 1
CONSTRAINT: BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1[I1] == 1
CONSTRAINT: BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1[I2] == 1
PARAMETERS: ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
```
#### Rule 21: DL Output Vector Dimension
```cpp
CONSTRAINT: CThreadTransferSrcDstVectorDim == 5
PARAMETERS: CThreadTransferSrcDstVectorDim
```
---
## Consolidated Validation Matrix
### For XDL-based Operations (V3, Standard, Large_Tensor)
| Parameter | Constraint | Depends On | Rule # |
|-----------|------------|------------|--------|
| MPerBlock | % (MPerXDL × MXdlPerWave) == 0 | MPerXDL, MXdlPerWave | 1 |
| NPerBlock | % (NXdlPerWave × NPerXDL) == 0 | NPerXDL, NXdlPerWave | 2 |
| KPerBlock | % AK1 == 0 AND % BK1 == 0 | AK1, BK1 | 3 |
| MXdlPerWave | % CShuffleMXdlPerWavePerShuffle == 0 | CShuffleMXdlPerWavePerShuffle | 4 |
| NXdlPerWave | % CShuffleNXdlPerWavePerShuffle == 0 | CShuffleNXdlPerWavePerShuffle | 5 |
| BlockSize | == MWaves × NWaves × WaveSize | All M/N tiling params | 6 |
| KPerThread | % KPack == 0 | KPerBlock, AK1, BK1 | 7 |
| MPerXDL, NPerXDL | is_xdl_supported(...) | AComputeDataType, BComputeDataType | 8 |
| AComputeDataType | == BComputeDataType (if TF32) | BComputeDataType | 9 |
| DirectLoad | Requires gfx950 & PassThrough ops | Device, ElementwiseOps | 10 |
### For WMMA-based Operations
| Parameter | Constraint | Depends On | Rule # |
|-----------|------------|------------|--------|
| MPerBlock | % (MPerWmma × MRepeat) == 0 | MPerWmma, MRepeat | 12 |
| NPerBlock | % (NPerWmma × NRepeat) == 0 | NPerWmma, NRepeat | 13 |
| BlockSize | == MWaves × NWaves × WaveSize | All M/N tiling params | 14 |
| KPack | % (K1 × 2) == 0 (for A and B) | K1 | 15 |
| AccDataType | == float OR == int32_t | - | 16 |
| Device | gfx11 or gfx12 only | - | 16 |
### For DL-based Operations
| Parameter | Constraint | Depends On | Rule # |
|-----------|------------|------------|--------|
| MPerBlock | % M1PerThread == 0 | M1PerThread | 17 |
| NPerBlock | % N1PerThread == 0 | N1PerThread | 17 |
| Thread Cluster | Must result in BM0=2, BN0=2 | M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs | 18 |
| BlockSize | == BM101 × BM100 × BN101 × BN100 | Thread cluster config | 19 |
| Source Vector Lengths | Entries [I1] and [I2] == 1 (A and B) | ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 | 20 |
| CThreadTransferSrcDstVectorDim | == 5 | - | 21 |

---
## Upstream Validation Strategy
### Phase 1: Compile-Time Template Validation
Create a validation struct that can be checked at compile-time:
```cpp
template <typename DeviceOp>
struct DeviceOpTemplateValidator {
    // Extract template parameters as constexpr values
    static constexpr auto BlockSize = DeviceOp::BlockSize;
    static constexpr auto MPerBlock = DeviceOp::MPerBlock;
    static constexpr auto NPerBlock = DeviceOp::NPerBlock;
    static constexpr auto KPerBlock = DeviceOp::KPerBlock;
    // ... extract all other parameters

    // Validate each rule. is_xdl_based is the operation-type trait
    // declared in the complete template later in this document.
    static constexpr bool ValidateRule1() {
        if constexpr(is_xdl_based<DeviceOp>::value) {
            return MPerBlock % (MPerXDL * MXdlPerWave) == 0;
        }
        return true; // N/A for non-XDL
    }

    static constexpr bool ValidateRule2() {
        if constexpr(is_xdl_based<DeviceOp>::value) {
            return NPerBlock % (NXdlPerWave * NPerXDL) == 0;
        }
        return true;
    }
    // ... all rules

    static constexpr bool IsValid() {
        return ValidateRule1() && ValidateRule2(); // && ... all other rules
    }

    // Optional: error text for logging. It cannot be passed to static_assert,
    // whose message must be a string literal before C++26.
    static constexpr const char* GetErrorMessage() {
        if constexpr(!ValidateRule1())
            return "MPerBlock not divisible by (MPerXDL * MXdlPerWave)";
        if constexpr(!ValidateRule2())
            return "NPerBlock not divisible by (NXdlPerWave * NPerXDL)";
        // ... etc
        return "Valid";
    }
};

// Usage: assert each rule separately so the diagnostic names the failing rule
using Validator = DeviceOpTemplateValidator<MyDeviceOp>;
static_assert(Validator::ValidateRule1(),
              "MPerBlock not divisible by (MPerXDL * MXdlPerWave)");
static_assert(Validator::ValidateRule2(),
              "NPerBlock not divisible by (NXdlPerWave * NPerXDL)");
```
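Because a `static_assert` message has to be a string literal before C++26, one assertion per rule is the simplest way to get a diagnostic that names the failing rule. The following self-contained sketch shows that layout against mock operations. `MockXdlOp`, `MockBadOp`, and `XdlRuleChecks` are hypothetical, and a real device operation may not expose its parameters as public static members.

```cpp
// Hypothetical stand-in for a device operation exposing its tuning parameters.
struct MockXdlOp {
    static constexpr int BlockSize   = 256;
    static constexpr int MPerBlock   = 128;
    static constexpr int NPerBlock   = 128;
    static constexpr int MPerXDL     = 32;
    static constexpr int NPerXDL     = 32;
    static constexpr int MXdlPerWave = 2;
    static constexpr int NXdlPerWave = 2;
};

// Same as MockXdlOp except for an MPerBlock that violates Rule 1 (96 % 64 != 0).
struct MockBadOp : MockXdlOp {
    static constexpr int MPerBlock = 96;
};

// One boolean per rule, so each rule can be asserted with its own message.
template <typename Op>
struct XdlRuleChecks {
    static constexpr bool Rule1 = Op::MPerBlock % (Op::MPerXDL * Op::MXdlPerWave) == 0;
    static constexpr bool Rule2 = Op::NPerBlock % (Op::NPerXDL * Op::NXdlPerWave) == 0;
    static constexpr bool All   = Rule1 && Rule2;
};

static_assert(XdlRuleChecks<MockXdlOp>::Rule1,
              "Rule 1: MPerBlock not divisible by (MPerXDL * MXdlPerWave)");
static_assert(XdlRuleChecks<MockXdlOp>::Rule2,
              "Rule 2: NPerBlock not divisible by (NXdlPerWave * NPerXDL)");
```

Asserting `XdlRuleChecks<MockBadOp>::Rule1` would stop compilation with the Rule 1 message, which is the behavior a builder wants for an invalid configuration.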
### Phase 2: Runtime Problem-Size Validation
After template instantiation, validate against actual problem dimensions:
```cpp
struct RuntimeValidator {
    // Check vector access divisibility
    static bool CheckVectorAccess(
        index_t C, index_t K, index_t G,
        index_t ABlockTransferSrcScalarPerVector,
        index_t BBlockTransferSrcScalarPerVector,
        index_t CDEBlockTransferScalarPerVector_NPerBlock,
        const Layout& layouts)
    {
        // Check C divisibility for A and B
        if (layouts.AUsesChannelVectorization())
            if (C % ABlockTransferSrcScalarPerVector != 0)
                return false;
        if (layouts.BUsesChannelVectorization())
            if (C % BBlockTransferSrcScalarPerVector != 0)
                return false;
        // Check K divisibility for E
        if (K % CDEBlockTransferScalarPerVector_NPerBlock != 0)
            return false;
        return true;
    }

    // Check tile size divisibility
    static bool CheckTileSizes(
        index_t M, index_t N, index_t K,
        index_t MPerBlock, index_t NPerBlock, index_t KPerBlock)
    {
        return (M % MPerBlock == 0) &&
               (N % NPerBlock == 0) &&
               (K % KPerBlock == 0);
    }

    // Check descriptor size limits
    static bool CheckDescriptorSizes(
        index_t M, index_t N, index_t K,
        size_t ADataTypeSize, size_t BDataTypeSize, size_t EDataTypeSize)
    {
        // Compute in 64 bits: M * K can overflow a 32-bit index_t,
        // and 1L << 31 overflows where long is 32 bits wide.
        constexpr long_index_t TwoGB = long_index_t{1} << 31;
        const auto bytes = [](index_t rows, index_t cols, size_t elem_size) {
            return static_cast<long_index_t>(rows) * cols *
                   static_cast<long_index_t>(elem_size);
        };
        return (bytes(M, K, ADataTypeSize) <= TwoGB) &&
               (bytes(N, K, BDataTypeSize) <= TwoGB) &&
               (bytes(M, N, EDataTypeSize) <= TwoGB);
    }
};
```
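The size-limit and vector-access checks can be exercised in isolation. The sketch below assumes `index_t` is a 32-bit and `long_index_t` a 64-bit signed integer; the aliases are redeclared locally so the snippet compiles on its own, and the helper names are hypothetical.

```cpp
#include <cstdint>

using index_t      = std::int32_t; // assumed width of the index type
using long_index_t = std::int64_t; // assumed width of the long index type

// True when rows * cols * elem_bytes stays within 2 GiB.
// The first operand is widened before multiplying to avoid 32-bit overflow.
constexpr bool FitsInTwoGB(index_t rows, index_t cols, index_t elem_bytes)
{
    return static_cast<long_index_t>(rows) * cols * elem_bytes <=
           (long_index_t{1} << 31);
}

// True when a dimension can be read in vectors of the given width.
constexpr bool VectorAccessOk(index_t dim, index_t scalar_per_vector)
{
    return scalar_per_vector > 0 && dim % scalar_per_vector == 0;
}

// 32768 * 32768 * 2 bytes is exactly 2^31, which the <= comparison accepts.
static_assert(FitsInTwoGB(32768, 32768, 2), "2 GiB tensor is at the limit");
// 65536 * 32768 already equals 2^31 elements, so a 32-bit product would overflow.
static_assert(!FitsInTwoGB(65536, 32768, 2), "4 GiB tensor exceeds the limit");
static_assert(VectorAccessOk(64, 8), "64 channels support 8-wide vectors");
```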
---
## Complete Validation Code Template
```cpp
#pragma once
#include <array>       // std::array, used by RuntimeParameterValidator
#include <type_traits>
namespace ck {
namespace validation {
// Trait to detect operation type
template<typename T> struct is_xdl_based : std::false_type {};
template<typename T> struct is_wmma_based : std::false_type {};
template<typename T> struct is_dl_based : std::false_type {};
// Specialize for each operation type...
template<typename DeviceOp>
struct TemplateParameterValidator {
    // Extract parameters
    static constexpr auto BlockSize = DeviceOp::BlockSize;
    static constexpr auto MPerBlock = DeviceOp::MPerBlock;
    static constexpr auto NPerBlock = DeviceOp::NPerBlock;
    static constexpr auto KPerBlock = DeviceOp::KPerBlock;

    // XDL-specific validation
    template<typename T = DeviceOp>
    static constexpr bool ValidateXDL() {
        if constexpr(is_xdl_based<T>::value) {
            constexpr auto MPerXDL = T::MPerXDL;
            constexpr auto NPerXDL = T::NPerXDL;
            constexpr auto MXdlPerWave = T::MXdlPerWave;
            constexpr auto NXdlPerWave = T::NXdlPerWave;
            constexpr auto AK1 = T::AK1;
            constexpr auto BK1 = T::BK1;
            constexpr auto CShuffleMXdlPerWavePerShuffle = T::CShuffleMXdlPerWavePerShuffle;
            constexpr auto CShuffleNXdlPerWavePerShuffle = T::CShuffleNXdlPerWavePerShuffle;

            // Rule 1
            if constexpr(MPerBlock % (MPerXDL * MXdlPerWave) != 0)
                return false;
            // Rule 2
            if constexpr(NPerBlock % (NXdlPerWave * NPerXDL) != 0)
                return false;
            // Rule 3
            if constexpr(KPerBlock % AK1 != 0 || KPerBlock % BK1 != 0)
                return false;
            // Rule 4
            if constexpr(MXdlPerWave % CShuffleMXdlPerWavePerShuffle != 0)
                return false;
            // Rule 5
            if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
                return false;
            // Rule 6
            constexpr auto MWaves = MPerBlock / (MPerXDL * MXdlPerWave);
            constexpr auto NWaves = NPerBlock / (NPerXDL * NXdlPerWave);
            constexpr auto WaveSize = 64; // Adjust based on arch
            if constexpr(BlockSize != MWaves * NWaves * WaveSize)
                return false;
            return true;
        }
        return true; // N/A
    }
    // WMMA-specific validation
    template<typename T = DeviceOp>
    static constexpr bool ValidateWMMA() {
        if constexpr(is_wmma_based<T>::value) {
            constexpr auto MPerWmma = T::MPerWmma;
            constexpr auto NPerWmma = T::NPerWmma;
            constexpr auto MRepeat = T::MRepeat;
            constexpr auto NRepeat = T::NRepeat;
            constexpr auto K1 = T::K1;

            // Rule 12
            if constexpr(MPerBlock % (MPerWmma * MRepeat) != 0)
                return false;
            // Rule 13
            if constexpr(NPerBlock % (NPerWmma * NRepeat) != 0)
                return false;
            // Rule 14
            constexpr auto MWaves = MPerBlock / (MPerWmma * MRepeat);
            constexpr auto NWaves = NPerBlock / (NPerWmma * NRepeat);
            constexpr auto WaveSize = 32;
            if constexpr(BlockSize != MWaves * NWaves * WaveSize)
                return false;
            // Rule 15 - KPack constraints
            // WmmaK is not consumed yet; marked maybe_unused to avoid a warning
            [[maybe_unused]] constexpr auto WmmaK = (K1 == 16) ? 32 : 16;
            // KPack computation is complex, may need runtime check
            return true;
        }
        return true; // N/A
    }
    // DL-specific validation
    template<typename T = DeviceOp>
    static constexpr bool ValidateDL() {
        if constexpr(is_dl_based<T>::value) {
            constexpr auto M1PerThread = T::M1PerThread;
            constexpr auto N1PerThread = T::N1PerThread;

            // Rule 17
            if constexpr(MPerBlock % M1PerThread != 0)
                return false;
            if constexpr(NPerBlock % N1PerThread != 0)
                return false;
            // Rules 18-19 require analyzing thread cluster which is complex
            // May need partial compile-time, partial runtime validation
            return true;
        }
        return true; // N/A
    }

    // Master validation
    static constexpr bool IsValid() {
        return ValidateXDL<DeviceOp>() &&
               ValidateWMMA<DeviceOp>() &&
               ValidateDL<DeviceOp>();
    }
};
// Runtime validation for problem-dependent constraints
template<typename DeviceOp>
struct RuntimeParameterValidator {
    static bool Validate(
        index_t G, index_t N, index_t C, index_t K,
        const std::array<index_t, DeviceOp::NDimSpatial>& spatial_dims_in,
        const std::array<index_t, DeviceOp::NDimSpatial>& spatial_dims_out)
    {
        // Calculate spatial products
        index_t input_spatial_acum = 1;
        index_t output_spatial_acum = 1;
        for(index_t i = 0; i < DeviceOp::NDimSpatial; ++i) {
            input_spatial_acum *= spatial_dims_in[i];
            output_spatial_acum *= spatial_dims_out[i];
        }

        // Check vector access based on layout
        if constexpr(DeviceOp::UsesChannelVectorForA()) {
            if(C % DeviceOp::ABlockTransferSrcScalarPerVector != 0)
                return false;
        }
        if constexpr(DeviceOp::UsesChannelVectorForB()) {
            if(C % DeviceOp::BBlockTransferSrcScalarPerVector != 0)
                return false;
        }
        if(K % DeviceOp::CDEBlockTransferScalarPerVector_NPerBlock != 0)
            return false;

        // Check tile divisibility
        const auto MPadded = DeviceOp::CalculateMPadded(/* M calculated from problem */);
        const auto NPadded = DeviceOp::CalculateNPadded(/* N calculated from problem */);
        const auto KPadded = DeviceOp::CalculateKPadded(/* K calculated from problem */);
        // Additional checks based on specialization...
        return true;
    }
};
} // namespace validation
} // namespace ck
```
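The template leaves the trait specializations open ("Specialize for each operation type..."). One way to fill them in is a partial specialization per device-operation template, so that every instantiation is classified automatically. The sketch below uses two hypothetical class templates, `FakeXdlConvFwd` and `FakeWmmaConvFwd`, in place of the real operations, whose parameter lists are much longer.

```cpp
#include <type_traits>

// Hypothetical stand-ins for device-operation templates.
template <int BlockSize, int MPerBlock> struct FakeXdlConvFwd {};
template <int BlockSize, int MPerBlock> struct FakeWmmaConvFwd {};

// Primary templates: an unknown type belongs to no family.
template <typename T> struct is_xdl_based : std::false_type {};
template <typename T> struct is_wmma_based : std::false_type {};

// Partial specializations match every instantiation of an op template.
template <int BlockSize, int MPerBlock>
struct is_xdl_based<FakeXdlConvFwd<BlockSize, MPerBlock>> : std::true_type {};

template <int BlockSize, int MPerBlock>
struct is_wmma_based<FakeWmmaConvFwd<BlockSize, MPerBlock>> : std::true_type {};

static_assert(is_xdl_based<FakeXdlConvFwd<256, 128>>::value, "XDL op is detected");
static_assert(!is_wmma_based<FakeXdlConvFwd<256, 128>>::value, "XDL op is not WMMA");
```

An alternative would be a tag member on each operation; the partial-specialization approach has the advantage of leaving the operation classes untouched.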
---
## Usage Example
```cpp
// At compile-time (before instantiation)
using MyDeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
    2,   // NDimSpatial
    /* ... layouts ... */
    256, // BlockSize
    128, // MPerBlock
    128, // NPerBlock
    16,  // KPerBlock
    8,   // AK1
    8,   // BK1
    32,  // MPerXDL
    32,  // NPerXDL
    2,   // MXdlPerWave
    2,   // NXdlPerWave
    /* ... other params ... */
>;

// Compile-time validation
static_assert(ck::validation::TemplateParameterValidator<MyDeviceOp>::IsValid(),
              "Invalid template parameters for device operation!");

// At runtime (when problem sizes are known)
bool is_valid = ck::validation::RuntimeParameterValidator<MyDeviceOp>::Validate(
    G, N, C, K, input_spatial_dims, output_spatial_dims);
```
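The numbers in this usage example can be checked by hand against Rules 1, 2, 3, and 6. The sketch below repeats that arithmetic as compile-time assertions; the wave size of 64 is an assumption about the target, and the `usage_check` namespace is only there to keep the names local.

```cpp
namespace usage_check {

constexpr int BlockSize = 256, MPerBlock = 128, NPerBlock = 128, KPerBlock = 16;
constexpr int AK1 = 8, BK1 = 8;
constexpr int MPerXDL = 32, NPerXDL = 32, MXdlPerWave = 2, NXdlPerWave = 2;
constexpr int WaveSize = 64; // assumption: wave64 target

// Each wave covers MPerXDL * MXdlPerWave = 64 rows, so 128 rows need 2 waves.
constexpr int MWaves = MPerBlock / (MPerXDL * MXdlPerWave);
constexpr int NWaves = NPerBlock / (NPerXDL * NXdlPerWave);

static_assert(MPerBlock % (MPerXDL * MXdlPerWave) == 0, "Rule 1");
static_assert(NPerBlock % (NXdlPerWave * NPerXDL) == 0, "Rule 2");
static_assert(KPerBlock % AK1 == 0 && KPerBlock % BK1 == 0, "Rule 3");
// 2 * 2 * 64 == 256
static_assert(BlockSize == MWaves * NWaves * WaveSize, "Rule 6");

} // namespace usage_check
```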
---
## Parameter Selection Guidelines
### Step 1: Choose Architecture-Specific Base Values
- **XDL**: MPerXDL=32, NPerXDL=32 (typical for fp16/bf16) or MPerXDL=16, NPerXDL=16 (typical for int8/fp8); the chosen pair must pass `is_xdl_supported` for the compute data types (Rule 8)
- **WMMA**: MPerWmma=16, NPerWmma=16
### Step 2: Determine Wave Configuration
- Calculate desired waves: MWaves × NWaves
- Ensure BlockSize = MWaves × NWaves × WaveSize
### Step 3: Calculate Block Tile Sizes
- **XDL**: MPerBlock = MPerXDL × MXdlPerWave × MWaves and NPerBlock = NPerXDL × NXdlPerWave × NWaves
- **WMMA**: MPerBlock = MPerWmma × MRepeat × MWaves and NPerBlock = NPerWmma × NRepeat × NWaves
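Steps 1 to 3 amount to multiplying out the chosen factors. A minimal sketch for the XDL case follows; `XdlTile` and `DeriveXdlTile` are hypothetical helpers, and the wave size is passed in because it depends on the target architecture.

```cpp
// Hypothetical result bundle for the derived block-level parameters.
struct XdlTile {
    int MPerBlock;
    int NPerBlock;
    int BlockSize;
};

// Derive block tile sizes and BlockSize from the per-wave choices.
constexpr XdlTile DeriveXdlTile(int MPerXDL, int NPerXDL,
                                int MXdlPerWave, int NXdlPerWave,
                                int MWaves, int NWaves, int WaveSize)
{
    return XdlTile{MPerXDL * MXdlPerWave * MWaves,  // Step 3, M direction
                   NPerXDL * NXdlPerWave * NWaves,  // Step 3, N direction
                   MWaves * NWaves * WaveSize};     // Step 2
}

// 32x32 XDL tiles, 2 tiles per wave, 2x2 waves of 64 threads
constexpr XdlTile kTile = DeriveXdlTile(32, 32, 2, 2, 2, 2, 64);
static_assert(kTile.MPerBlock == 128 && kTile.NPerBlock == 128, "tile is 128 x 128");
static_assert(kTile.BlockSize == 256, "four waves of 64 threads");
```

Deriving the block sizes from the factors, rather than choosing them independently, makes Rules 1, 2, and 6 hold by construction.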
### Step 4: Choose K Decomposition
- Select AK1, BK1 (typically 4, 8, or 16)
- Ensure KPerBlock % AK1 == 0 and KPerBlock % BK1 == 0
### Step 5: Choose Shuffle Parameters
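- CShuffleMXdlPerWavePerShuffle ≤ MXdlPerWave and CShuffleNXdlPerWavePerShuffle ≤ NXdlPerWave
- Ensure MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 and NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0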
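Valid shuffle values are the divisors of the corresponding XdlPerWave parameter. The sketch below picks the largest divisor that does not exceed a caller-chosen cap; `LargestValidShuffle` is a hypothetical helper, and the cap stands in for whatever resource limit motivates a smaller shuffle tile.

```cpp
// Largest d with 1 <= d <= cap and xdl_per_wave % d == 0.
// Assumes xdl_per_wave >= 1; returns 1 when cap < 1.
constexpr int LargestValidShuffle(int xdl_per_wave, int cap)
{
    for(int d = (cap < xdl_per_wave ? cap : xdl_per_wave); d > 1; --d)
    {
        if(xdl_per_wave % d == 0)
            return d;
    }
    return 1; // 1 always divides xdl_per_wave
}

// Divisors of 4 are 1, 2, 4; with a cap of 3 the largest usable one is 2.
static_assert(LargestValidShuffle(4, 3) == 2, "cap excludes 4, and 3 does not divide 4");
static_assert(LargestValidShuffle(4, 8) == 4, "cap above the parameter is clamped");
```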
### Step 6: Select Vector Transfer Widths
- Match to memory access pattern and data type size
- Ensure alignment with channel/output dimensions
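For the alignment requirement, a simple starting point is the widest power-of-two vector that divides the vectorized dimension. This sketch assumes the runtime condition is plain divisibility of that dimension by the ScalarPerVector value and that the maximum width is a power of two; `LargestVectorWidth` is a hypothetical helper, and the maximum itself depends on the data type and hardware.

```cpp
// Widest power-of-two vector width <= max_width that divides dim.
// Assumes max_width is a power of two and dim >= 1.
constexpr int LargestVectorWidth(int dim, int max_width)
{
    for(int w = max_width; w > 1; w /= 2)
    {
        if(dim % w == 0)
            return w;
    }
    return 1; // scalar access always works
}

static_assert(LargestVectorWidth(64, 8) == 8, "64 channels allow the full width");
static_assert(LargestVectorWidth(12, 8) == 4, "12 is divisible by 4 but not by 8");
static_assert(LargestVectorWidth(3, 8) == 1, "odd channel counts fall back to scalar");
```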
### Step 7: Validate Against All Rules
- Run compile-time validator
- Test with representative problem sizes