From 6a116fa958b93bb0efce452252015b8f0616404c Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 1 Jul 2025 10:21:25 +0000 Subject: [PATCH] Modified the template parameters to make the instances work. --- ...16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 8 +++---- ...e_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 24 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index cb32f2d118..58a42ae223 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -28,7 +28,7 @@ using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generi //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -39,9 +39,9 @@ using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instan //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 0dd19c6916..0e0fa7497f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -28,41 +28,41 @@ using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_in //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8> // clang-format on >; -using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< +using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( std::vector>>& instances) { add_device_operation_instances( instances, - device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances{}); add_device_operation_instances( - instances, device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); + instances, device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); } } // namespace instance