From 54e7d86ee2554128b50d9a5a9b1b89936e431518 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Wed, 7 Jan 2026 16:16:37 +0000 Subject: [PATCH] Merge commit 'a7d6b1e7008c0b6e1af8a7d79389aefbdca4da65' into develop --- .../gemm_bilinear_wmma_fp16.cpp | 2 +- .../gemm_bilinear_wmma_int8.cpp | 2 +- example/12_reduce/reduce_blockwise.cpp | 2 +- .../reduce_multiblock_atomic_add.cpp | 2 +- .../12_reduce/reduce_threadwise_multi_d.cpp | 2 +- example/13_pool2d_fwd/pool2d_fwd_fp16.cpp | 2 +- ...rouped_gemm_multiple_d_splitk_xdl_fp16.cpp | 2 +- .../grouped_gemm_multiple_d_xdl_fp16.cpp | 2 +- .../gemm_add_add_mean_meansquare_xdl_fp16.cpp | 2 +- .../gemm_add_addsquare_xdl_int8.cpp | 2 +- .../gemm_max_xdl_bf16.cpp | 2 +- .../gemm_max_xdl_fp16.cpp | 2 +- .../gemm_max_xdl_fp32.cpp | 2 +- .../gemm_max_xdl_int4.cpp | 2 +- .../gemm_max_xdl_int8.cpp | 2 +- .../gemm_mean_meansquare_xdl_bf16.cpp | 2 +- .../gemm_mean_meansquare_xdl_fp16.cpp | 2 +- .../gemm_mean_meansquare_xdl_fp32.cpp | 2 +- example/22_cgemm/cgemm_xdl_int4.cpp | 2 +- example/23_softmax/softmax_blockwise.cpp | 2 +- ..._batched_gemm_example_fp16int4_b_scale.inc | 2 +- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 2 +- .../30_grouped_conv_fwd_multiple_d/common.hpp | 2 +- .../common_wmma.hpp | 2 +- .../33_multiple_reduce/dual_reduce_common.hpp | 2 +- example/35_splitK_gemm/common.hpp | 2 +- .../sparse_embedding3_forward_layernorm.cpp | 2 +- ...ed_gemm_add_add_relu_gemm_add_xdl_fp16.cpp | 2 +- ...bias_relu_perchannel_quantization_int8.cpp | 2 +- ...l_bias_relu_perlayer_quantization_int8.cpp | 2 +- ...bias_tanh_perchannel_quantization_int8.cpp | 2 +- ...l_bias_tanh_perlayer_quantization_int8.cpp | 2 +- ...2d_fwd_dl_perchannel_quantization_int8.cpp | 2 +- ...bias_relu_perchannel_quantization_int8.cpp | 2 +- ...l_bias_relu_perlayer_quantization_int8.cpp | 2 +- ...d_fwd_xdl_perchannel_quantization_int8.cpp | 2 +- .../run_groupnorm_fwd_example.inc | 2 +- .../elementwise_binary_4D_fp16.cpp | 2 +- .../elementwise_permute_4D_fp16.cpp | 2 +- .../elementwise_permute_4D_fp16_col.cpp | 2 +- .../elementwise_permute_4D_fp16_row.cpp | 2 +- .../elementwise_permute_4D_fp32_col.cpp | 2 +- .../elementwise_permute_4D_fp32_row.cpp | 2 +- ...entwise_scale_permute_amax_2D_fp16_fp8.cpp | 2 +- .../elementwise_trinary_4D_fp16.cpp | 2 +- .../elementwise_layernorm_blockwise.cpp | 2 +- .../moe_gemm1_xdl_fp8.cpp | 2 +- .../moe_gemm1_xdl_fp8_blockscale.cpp | 2 +- .../moe_gemm1_xdl_pk_i4.cpp | 2 +- .../moe_gemm2_xdl_fp8.cpp | 2 +- .../moe_gemm2_xdl_fp8_blockscale.cpp | 2 +- .../moe_gemm2_xdl_pk_i4.cpp | 2 +- .../moe_gemm1_xdl_mx_fp4.cpp | 2 +- .../moe_gemm1_xdl_mx_fp4_bns.cpp | 2 +- .../moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp | 2 +- .../moe_gemm2_xdl_mx_fp4.cpp | 2 +- .../moe_gemm2_xdl_mx_fp4_bns.cpp | 2 +- .../moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp | 2 +- experimental/builder/test/CMakeLists.txt | 3 +- .../conv/ck/unit_instance_to_conv_traits.cpp | 1128 +++++++++++++++++ .../test_batched_gemm_multi_d_dl.cpp | 2 +- test/gemm/gemm_standalone_xdl_fp16.cpp | 2 +- test/wrapper/test_wrapper_gemm_xdl.cpp | 2 +- 63 files changed, 1191 insertions(+), 62 deletions(-) create mode 100644 experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 0bded7d2ac..9b48d5765d 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -119,7 +119,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 4acf4fe9ff..a770bf5c77 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -119,7 +119,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 55f3d99823..f8299028da 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -31,7 +31,7 @@ class SimpleAppArgs bool do_verification = true; int data_type = 1; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/12_reduce/reduce_multiblock_atomic_add.cpp b/example/12_reduce/reduce_multiblock_atomic_add.cpp index af5903f83c..66fc2bb582 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add.cpp +++ b/example/12_reduce/reduce_multiblock_atomic_add.cpp @@ -31,7 +31,7 @@ class SimpleAppArgs bool do_verification = true; int data_type = 1; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/12_reduce/reduce_threadwise_multi_d.cpp b/example/12_reduce/reduce_threadwise_multi_d.cpp index e77daea212..ee06395771 100644 --- a/example/12_reduce/reduce_threadwise_multi_d.cpp +++ b/example/12_reduce/reduce_threadwise_multi_d.cpp @@ -31,7 +31,7 @@ class SimpleAppArgs bool do_verification = true; int data_type = 1; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp index f0a9ce9270..fc083ba3e2 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -53,7 +53,7 @@ int main(int argc, char* argv[]) { do_verification = true; init_method = 1; - time_kernel = true; + time_kernel = false; } else if(argc == 4) { diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp index 62d2022084..6fe285f165 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -90,7 +90,7 @@ struct ExecutionConfig final bool do_verification = true; int init_method = 1; int k_batch = 128; - bool time_kernel = true; + bool time_kernel = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 1db8a9defb..0e1a38b19a 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -89,7 +89,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index 08915fdd26..a30bedf282 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -268,7 +268,7 @@ int main() pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); } - bool time_kernel = true; + bool time_kernel = false; if(time_kernel) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp index 7a81d82c25..3401494625 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp @@ -302,7 +302,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp index 5a127d1cd4..e4960668eb 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp @@ -106,7 +106,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp index 29be3dde0a..c97fa7ebc5 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp @@ -106,7 +106,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp index 0574488e04..f32d5e9f6d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -106,7 +106,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp index 7da40adc90..6c9fb8da75 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp @@ -108,7 +108,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp index 47f1d50ef5..4a63bee894 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp @@ -105,7 +105,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp index cac3db3078..ebd71f1799 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp index 5ea09cfab2..1153a66615 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp index 8e120851ec..6b5dde3cc7 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/22_cgemm/cgemm_xdl_int4.cpp b/example/22_cgemm/cgemm_xdl_int4.cpp index 47b0e1d5a5..4f21c70562 100644 --- a/example/22_cgemm/cgemm_xdl_int4.cpp +++ b/example/22_cgemm/cgemm_xdl_int4.cpp @@ -81,7 +81,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // CGEMM shape ck::index_t M = 1024; diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp index a741cb8133..0455819cdc 100644 --- a/example/23_softmax/softmax_blockwise.cpp +++ b/example/23_softmax/softmax_blockwise.cpp @@ -65,7 +65,7 @@ class SimpleAppArgs bool do_verification = true; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 12d7cf0aa6..86a36d53e2 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -27,7 +27,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; template diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 6efed7eb29..06bf971ac4 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -248,7 +248,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; ck::index_t G0 = 1; ck::index_t G1 = 2; diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index e1939d4300..dce9f62293 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -92,7 +92,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; #define DefaultConvParam \ diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp index ca8cba039f..2b27405ecd 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -92,7 +92,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; #define DefaultConvParam \ diff --git a/example/33_multiple_reduce/dual_reduce_common.hpp b/example/33_multiple_reduce/dual_reduce_common.hpp index 3f04af5e89..923b5b6f15 100644 --- a/example/33_multiple_reduce/dual_reduce_common.hpp +++ b/example/33_multiple_reduce/dual_reduce_common.hpp @@ -40,7 +40,7 @@ class SimpleAppArgs bool do_verification = true; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: SimpleAppArgs() diff --git a/example/35_splitK_gemm/common.hpp b/example/35_splitK_gemm/common.hpp index d0f03f3611..8bf09ee786 100644 --- a/example/35_splitK_gemm/common.hpp +++ b/example/35_splitK_gemm/common.hpp @@ -44,7 +44,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; }; template diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index 2f290497c9..ea8858b958 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -56,7 +56,7 @@ template<> struct emb_kernel { using kernel_type = DeviceInsta int main(int argc, char* argv[]) { - bool time_kernel = true; + bool time_kernel = false; ck::index_t num_rows = 65536; constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index dc0b95863e..ab87124c6b 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -195,7 +195,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp index c6cc9c6a15..9e7039461c 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp @@ -86,7 +86,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp index 0f49cb5a38..fa6a36c212 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp @@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp index 5652cc38ab..45651da757 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp @@ -87,7 +87,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp index 138a214127..cda4c1419c 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp @@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp index 1652cea214..0e52ac280a 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp @@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index f127940377..9bff452a67 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -90,7 +90,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 7a03a3efe0..17a7b632af 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index 155024dc62..345277e092 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index b1596b5a53..d5f9b831f0 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -12,7 +12,7 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) ck::index_t C = 128; bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; bool log_kernel = true; if(argc == 1) diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp index 14b338c9c5..e90880dabd 100644 --- a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -53,7 +53,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; std::vector nchw = {16, 128, 32, 64}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index a7d139fc95..2b99d9261f 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -46,7 +46,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index cd1db4cdaf..276aa7f3c7 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 683c5cb072..0842325bad 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index abfd3ccf7c..a48f2349c9 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -49,7 +49,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index ff4e8f3a3d..39d88c47a1 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 939860bf69..3aef0fdaac 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -121,7 +121,7 @@ void reference_scale_permute_amax(Tensor& input, int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; const float scale = 2.f; diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp index 497f1c67c8..86af00e4fb 100644 --- a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -58,7 +58,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index eb95128f38..71cee9c420 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -84,7 +84,7 @@ void host_elementwise2D(HostTensorC& C, int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; ck::index_t M = 48 * 256; ck::index_t N = 1024; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index c0452b6067..10f7a38863 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -205,7 +205,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t N = 4096; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index ecc3034bba..d6082e5882 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -193,7 +193,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; #if 1 // GEMM shape ck::index_t N = 4096; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 0067c1d1fb..a2002270dc 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -194,7 +194,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index a602838c30..9f4cd13573 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -185,7 +185,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index fb5e3b6456..552d3cd7b5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -188,7 +188,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // tokens = 1 // topk = 1 diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index f56410d37a..377b53b519 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -164,7 +164,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index 3ce059ba20..586ecd81bf 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -178,7 +178,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp index d1d601977d..b3b2ebcbc0 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -178,7 +178,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index 0078cc5625..5c7668ab73 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -208,7 +208,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 202241d14f..04c3afc62b 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -171,7 +171,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 660ccabc94..12bb76eccd 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -171,7 +171,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index f398959114..6a5f5a6b9f 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -204,7 +204,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index d13c8cfdd9..233eafc366 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -104,7 +104,8 @@ target_link_libraries(test_ckb_reference_execution PRIVATE utility) # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits - conv/ck/test_conv_traits.cpp) + conv/ck/test_conv_traits.cpp + conv/ck/unit_instance_to_conv_traits.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp new file mode 100644 index 0000000000..de2a4fdd14 --- /dev/null +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp @@ -0,0 +1,1128 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// ============================================================================ +// Unit Tests for InstanceTraits to ConvTraits Conversion +// ============================================================================ +// +// PURPOSE: +// -------- +// These tests verify the conversion layer between InstanceTraits (low-level +// template parameter extraction) and ConvTraits (high-level semantic traits). +// The conversion transforms raw CK kernel parameters into builder-friendly +// enums and structures. +// +// DESIGN RATIONALE: +// ----------------- +// ConvTraits uses a single generic specialization that works with any Device +// class satisfying the IsXdlFwdConv concept. This use of concepts is fragile +// and introduces extra complexity. We want to refector to just use functions +// for this conversion. +// +// These tests are intentionally verbose and repetitive to provide maximum +// coverage during refactoring. Once the refactoring is complete and stable, +// they can be simplified or consolidated. +// +// TEST COVERAGE: +// -------------- +// 1. Enum conversion functions (pipeline version, scheduler, etc.) +// 2. Signature extraction (direction, specialization, layout, data type) +// 3. Full transformation verification for each XDL Device class template: +// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +// - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// +// NOTE: WMMA and DL (Direct Load) variants are not covered as they don't +// satisfy the IsXdlFwdConv concept (different tile parameter structure). +// ============================================================================ + +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/types.hpp" +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using ck_tile::builder::ConvDirection; +using ck_tile::builder::DataType; +using ck_tile::builder::ElementwiseOperation; +using ck_tile::builder::GemmPadding; +using ck_tile::builder::PipelineScheduler; +using ck_tile::builder::PipelineVersion; +using ck_tile::builder::TensorLayout; +using ::testing::ElementsAre; + +// ============================================================================ +// Test Enum Conversion Functions +// ============================================================================ + +TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion) +{ + using ck_tile::reflect::conv::convert_pipeline_version; + using enum ::ck::BlockGemmPipelineVersion; + using enum ::ck_tile::builder::PipelineVersion; + + EXPECT_EQ(convert_pipeline_version(), V1); + EXPECT_EQ(convert_pipeline_version(), V2); + EXPECT_EQ(convert_pipeline_version(), V3); + EXPECT_EQ(convert_pipeline_version(), V4); + EXPECT_EQ(convert_pipeline_version(), V5); +} + +TEST(InstanceToConvTraits, ConvertsPipelineVersion) +{ + using ck_tile::reflect::conv::convert_pipeline_version; + using enum ck::PipelineVersion; + using enum PipelineVersion; + + EXPECT_EQ(convert_pipeline_version(), V1); + EXPECT_EQ(convert_pipeline_version(), V2); + EXPECT_EQ(convert_pipeline_version(), V4); + EXPECT_EQ(convert_pipeline_version(), WEIGHT_ONLY); +} + +TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler) +{ + using ck_tile::reflect::conv::convert_pipeline_scheduler; + using enum ck::BlockGemmPipelineScheduler; + using enum PipelineScheduler; + + EXPECT_EQ(convert_pipeline_scheduler(), INTRAWAVE); + EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); +} + +TEST(InstanceToConvTraits, ConvertsLoopScheduler) +{ + using ck_tile::reflect::conv::convert_pipeline_scheduler; + using enum ck::LoopScheduler; + using enum PipelineScheduler; + + EXPECT_EQ(convert_pipeline_scheduler(), DEFAULT); + EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); +} + +// ============================================================================ +// Test Convolution Direction Detection +// ============================================================================ + +TEST(InstanceToConvTraits, DetectsForwardDirection) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); +} + +// ============================================================================ +// Test Convolution Specialization Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsDefaultSpecialization) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); +} + +TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::conv_specialization, + ck_tile::builder::ConvFwdSpecialization::FILTER_1X1_PAD0); +} + +// ============================================================================ +// Test Layout Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsGnhwcLayout) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_THAT(Traits::layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); +} + +TEST(InstanceToConvTraits, ExtractsNhwgcLayout) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::NHWGC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::NHWGK, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_THAT(Traits::layout, + ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK)); +} + +TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::NGCHW, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::NGKHW, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_THAT(Traits::layout, + ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW)); +} + +TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::NGCHW, + ck::tensor_layout::convolution::GKCYX, + ck::Tuple<>, + ck::tensor_layout::convolution::NGKHW, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_THAT(Traits::layout, + ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW)); +} + +// ============================================================================ +// Test Data Type Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsFp16DataType) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::data_type, DataType::FP16); +} + +TEST(InstanceToConvTraits, ExtractsBf16DataType) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::bhalf_t, + ck::Tuple<>, + ck::bhalf_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::bhalf_t, + ck::bhalf_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::data_type, DataType::BF16); +} + +TEST(InstanceToConvTraits, ExtractsFp32DataType) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + float, + float, + float, + float, + ck::Tuple<>, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + float, + float, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::data_type, DataType::FP32); +} + +TEST(InstanceToConvTraits, ExtractsI8DataType) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + int8_t, + int8_t, + int32_t, + int8_t, + ck::Tuple<>, + int8_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + int8_t, + int8_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::data_type, DataType::I8); +} + +// ============================================================================ +// Test GEMM Padding Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::gemm_padding, GemmPadding::DEFAULT); +} + +TEST(InstanceToConvTraits, ExtractsMnkGemmPadding) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::Tuple<>, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + float, + ck::half_t, + ck::Tuple<>, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + 256, + 128, + 128, + 16, + 8, + 8, + 32, + 32, + 4, + 4, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + ck::Sequence<4, 64, 1>, + ck::Sequence<1, 0, 2>, + ck::Sequence<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + ck::Sequence<1, 32, 1, 8>, + 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::half_t, + ck::half_t, + false>; + + using Traits = ck_tile::reflect::conv::ConvTraits; + + EXPECT_EQ(Traits::gemm_padding, GemmPadding::MNK_PADDING); +} + +// ============================================================================ +// Comprehensive Transformation Tests - Per Device Class Template +// ============================================================================ +// These tests verify the complete InstanceTraits → ConvTraits transformation +// for each forward convolution Device class template. They are verbose to +// provide maximum safety during refactoring. +// ============================================================================ + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + using InstTraits = ck_tile::reflect::InstanceTraits; + using ConvTraits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); + EXPECT_EQ(ConvTraits::data_type, DataType::FP16); + EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); + + // Verify tile dimensions + EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); + + // Verify pipeline configuration + EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + using InstTraits = ck_tile::reflect::InstanceTraits; + using ConvTraits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); + EXPECT_EQ(ConvTraits::data_type, DataType::FP16); + EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); + + // Verify tile dimensions + EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); + + // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler) + EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + using InstTraits = ck_tile::reflect::InstanceTraits; + using ConvTraits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); + EXPECT_EQ(ConvTraits::data_type, DataType::FP16); + EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); + + // Verify tile dimensions + EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); + + // Verify pipeline configuration + EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); +} + +} // anonymous namespace diff --git a/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp b/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp index e26ac53abe..2403c564b7 100644 --- a/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp +++ b/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp @@ -61,7 +61,7 @@ class TestBatchedGemmMultiD : public ::testing::Test true, // do_verification 1, // init_method false, // do_log - 1, // time_kernel, + false, // time_kernel, M, N, K, diff --git a/test/gemm/gemm_standalone_xdl_fp16.cpp b/test/gemm/gemm_standalone_xdl_fp16.cpp index 90a5a325b8..2df67a083a 100644 --- a/test/gemm/gemm_standalone_xdl_fp16.cpp +++ b/test/gemm/gemm_standalone_xdl_fp16.cpp @@ -104,7 +104,7 @@ int main(int argc, char* argv[]) }; bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; int problem_index = -1; if(argc == 1) diff --git a/test/wrapper/test_wrapper_gemm_xdl.cpp b/test/wrapper/test_wrapper_gemm_xdl.cpp index b9d4bc3e57..b8965a217b 100644 --- a/test/wrapper/test_wrapper_gemm_xdl.cpp +++ b/test/wrapper/test_wrapper_gemm_xdl.cpp @@ -306,7 +306,7 @@ void PerformGemm(const ck::index_t M, const auto kernel = DeviceGemm; - const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, false}, kernel, dim3(grid_size_x, grid_size_y, 1), dim3(ck::wrapper::size(thread_layout)),