diff --git a/Jenkinsfile b/Jenkinsfile index 11a9d9eb74..3fbcdb5849 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1039,8 +1039,8 @@ pipeline { description: "Use the CK build to verify hipTensor build and tests (default: OFF)") string( name: 'hipTensor_branch', - defaultValue: 'mainline', - description: 'Specify which branch of hipTensor to use (default: mainline)') + defaultValue: 'develop', + description: 'Specify which branch of hipTensor to use (default: develop)') booleanParam( name: "USE_SCCACHE", defaultValue: true, @@ -1190,7 +1190,6 @@ pipeline { when { beforeAgent true expression { env.SHOULD_RUN_CI.toBoolean() } - expression { params.RUN_CPPCHECK.toBoolean() } } parallel{ stage('Clang Format and Cppcheck') { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f898d5f7b2..533f7f2f23 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -608,7 +608,7 @@ class KernelComponentFactory: else: pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and skip == "f": pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": diff --git a/example/ck_tile/20_grouped_convolution/gemm_configs.hpp b/example/ck_tile/20_grouped_convolution/gemm_configs.hpp index 37a63cd65c..77e1c3af1a 
100644 --- a/example/ck_tile/20_grouped_convolution/gemm_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/gemm_configs.hpp @@ -226,20 +226,20 @@ struct ConvTypeConfig; template <> struct ConvTypeConfig { - using InDataType = ck_tile::half_t; - using WeiDataType = ck_tile::half_t; + using InDataType = ck_tile::half_t; + using WeiDataType = ck_tile::half_t; using AccDataType = float; - using OutDataType = ck_tile::half_t; + using OutDataType = ck_tile::half_t; // ToDo: Add more bias config to support different categories of GEMM. }; template <> struct ConvTypeConfig { - using InDataType = ck_tile::bf16_t; - using WeiDataType = ck_tile::bf16_t; + using InDataType = ck_tile::bf16_t; + using WeiDataType = ck_tile::bf16_t; using AccDataType = float; - using OutDataType = ck_tile::bf16_t; + using OutDataType = ck_tile::bf16_t; }; template diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index b7b14f9d13..496697ca32 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -4,9 +4,18 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming - AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline - BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline -- Row and Column-wise scaled: All of the rowwise elements in A Matrix and columwise elements in B Matrix will share the same quantization element and the elementwisde operation will complete in epilogue. +- Row and Column-wise scaled: All of the row-wise elements in A Matrix and column-wise elements in B Matrix will share the same quantization element and the element-wise operation will complete in epilogue. 
- Tensor-wise scaled: Share the same scalar scale across the whole tensor of A or B +## Quantization Mode Comparison + +| Quant Mode | A Matrix Organization | A Scale Shape | B Matrix Organization | B Scale Shape | +|------------|----------------------|---------------|----------------------|---------------| +| **AQuant** | Blocks along K dimension
Each 1×GroupSize block (GroupSize consecutive K elements within a row) shares one scale | `[M, K/GroupSize]` | Not quantized | N/A | +| **BQuant** | Not quantized | N/A | Blocks along K dimension
Each GroupSize×1 block (GroupSize consecutive K elements within a column) shares one scale | `[K/GroupSize, N]` | +| **RowColQuant** | Per-row quantization
All K elements in each row share one scale | `[M, 1]` | Per-column quantization
All K elements in each column share one scale | `[1, N]` | +| **TensorQuant** | Tensor-wise quantization
All M×K elements share one scale | `[1]` | Tensor-wise quantization
All K×N elements share one scale | `[1]` | + --- ## Features diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index c9cc56d033..0752dfdde4 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -451,4 +451,7 @@ int run_gemm_example(int argc, char* argv[]) } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + return !run_gemm_example(argc, argv); +} diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index 071ea2dccc..15c56f9261 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -44,13 +44,13 @@ struct GroupedConvBwdDataKernelArgs CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -145,15 +145,15 @@ struct GroupedConvBwdDataKernelArgs CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), 
- static_cast(args.input_spatial_lengths_[1])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -161,13 +161,13 @@ struct GroupedConvBwdDataKernelArgs static_cast(args.output_spatial_lengths_[1])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1])}; + static_cast(args.conv_filter_strides_[1])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1])}; + static_cast(args.input_left_pads_[1])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1])}; + static_cast(args.input_right_pads_[1])}; k_batch = args.k_batch; @@ -262,17 +262,17 @@ struct GroupedConvBwdDataKernelArgs CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1]), - static_cast(args.input_spatial_lengths_[2])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - 
static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1]), - static_cast(args.filter_spatial_lengths_[2])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -281,17 +281,17 @@ struct GroupedConvBwdDataKernelArgs static_cast(args.output_spatial_lengths_[2])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1]), - static_cast(args.conv_filter_strides_[2])}; + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1]), static_cast(args.conv_filter_dilations_[2])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1]), - static_cast(args.input_left_pads_[2])}; + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1]), - static_cast(args.input_right_pads_[2])}; + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; k_batch = args.k_batch; @@ -387,8 +387,8 @@ struct GroupedConvBwdDataKernelArgs static constexpr index_t MaxGroupedGemmGroupsNum = 128; - using ABCGridDescs = remove_cvref_t; + using ABCGridDescs = remove_cvref_t< + decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>; using AGridDescMK = remove_cvref_t{}])>; using BGridDescNK = remove_cvref_t{}])>; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 
14a04615dd..83ecb34a79 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -40,13 +40,13 @@ struct GroupedConvBwdWeightKernelArgs CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -109,15 +109,15 @@ struct GroupedConvBwdWeightKernelArgs CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -125,13 +125,13 @@ struct GroupedConvBwdWeightKernelArgs static_cast(args.output_spatial_lengths_[1])}; conv_filter_strides = 
{static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1])}; + static_cast(args.conv_filter_strides_[1])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1])}; + static_cast(args.input_left_pads_[1])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1])}; + static_cast(args.input_right_pads_[1])}; k_batch = args.k_batch; @@ -185,17 +185,17 @@ struct GroupedConvBwdWeightKernelArgs CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1]), - static_cast(args.input_spatial_lengths_[2])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1]), - static_cast(args.filter_spatial_lengths_[2])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -204,17 +204,17 @@ struct GroupedConvBwdWeightKernelArgs static_cast(args.output_spatial_lengths_[2])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1]), - static_cast(args.conv_filter_strides_[2])}; + static_cast(args.conv_filter_strides_[1]), + 
static_cast(args.conv_filter_strides_[2])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1]), static_cast(args.conv_filter_dilations_[2])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1]), - static_cast(args.input_left_pads_[2])}; + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1]), - static_cast(args.input_right_pads_[2])}; + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; k_batch = args.k_batch; @@ -257,8 +257,8 @@ struct GroupedConvBwdWeightKernelArgs GemmBatch = args.G_; } - using ABCGridDescs = remove_cvref_t; + using ABCGridDescs = remove_cvref_t< + decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>; using AGridDescKM = remove_cvref_t{}])>; using BGridDescKN = remove_cvref_t{}])>; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 7d7f8b1cf2..0363782d33 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -41,13 +41,13 @@ struct GroupedConvFwdKernelArgs CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0])}; + static_cast(args.K_), + static_cast(args.C_), + 
static_cast(args.filter_spatial_lengths_[0])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -124,15 +124,15 @@ struct GroupedConvFwdKernelArgs CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -140,13 +140,13 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[1])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1])}; + static_cast(args.conv_filter_strides_[1])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1])}; + static_cast(args.input_left_pads_[1])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1])}; + static_cast(args.input_right_pads_[1])}; k_batch = args.k_batch; @@ -216,17 +216,17 @@ struct GroupedConvFwdKernelArgs CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - 
static_cast(args.input_spatial_lengths_[1]), - static_cast(args.input_spatial_lengths_[2])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1]), - static_cast(args.filter_spatial_lengths_[2])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -235,17 +235,17 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[2])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1]), - static_cast(args.conv_filter_strides_[2])}; + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1]), static_cast(args.conv_filter_dilations_[2])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1]), - static_cast(args.input_left_pads_[2])}; + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1]), - static_cast(args.input_right_pads_[2])}; + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; k_batch = args.k_batch; @@ -306,15 +306,15 @@ struct GroupedConvFwdKernelArgs args.output_spatial_lengths_[2]; } - using AGridDescMK = remove_cvref_t())>; - using BGridDescNK = remove_cvref_t())>; - using CGridDescMN = 
remove_cvref_t())>; + using AGridDescMK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeADescriptor_M_K())>; + using BGridDescNK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeBDescriptor_N_K())>; + using CGridDescMN = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeCDescriptor_M_N())>; static constexpr index_t NonSpatialDims = 3; array in_g_n_c_wis_lengths; diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index 0b73fe7adf..2369b2eac8 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -177,12 +177,12 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, in_device_buf.FromDevice(in_device.mData.data()); using ComputeType_ = std::conditional_t; + OutDataType, + WeiDataType>; using ComputeType = std::conditional_t; + ComputeType_, + ComputeDataType>; using AccDataType = std::conditional_t, int32_t, float>; const index_t num_accums = conv_param.K_; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 96df4e32a1..292bc41a0b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -45,7 +45,6 @@ set(REGRESSION_TESTS test_ck_tile_fmha_fwd_bf16 test_ck_tile_fmha_fwd_fp16 test_ck_tile_fmha_fwd_fp8 - test_ck_tile_streamk_extended ) function(add_test_executable TEST_NAME) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index ae527a24f7..ec5d56d46d 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -24,98 +24,98 @@ if(GPU_TARGETS MATCHES "gfx9") ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp ) - - 
add_gtest_executable(test_ck_tile_streamk_extended - # compv3 pipeline - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # TODO: enable extended tests after tolerances for atomic reductions are addressed. 
+ # add_gtest_executable(test_ck_tile_streamk_extended + # # compv3 pipeline + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - 
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # 
#${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # TODO: add compv4 pipeline - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # 
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # TODO: add compv4 pipeline + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # 
#${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # 
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # 
#${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - # mem pipeline - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rrr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rrc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rcr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rcc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_crr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_crc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - 
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # # mem pipeline + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rrr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rrc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rcr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rcc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_crr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_crc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rrr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rrc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rcr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rcc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_crr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_crc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp - ) + # 
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rrr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rrc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rcr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rcc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_crr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_crc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + # ) else() message(DEBUG "Skipping test_ck_tile_streamk tests for current target") endif() diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp index 021cc303ad..062afb2664 100644 --- a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp +++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp @@ -13,119 +13,90 @@ struct matrix_core_swizzle_traits using matrix_core_swizzle_args = matrix_core_swizzle_host_args; -// host API -template // only supported with fp16 data type -float matrix_core_swizzle(matrix_core_swizzle_traits, - matrix_core_swizzle_args, - const ck_tile::stream_config&); - -template <> -float matrix_core_swizzle(matrix_core_swizzle_traits t, - matrix_core_swizzle_args a, - const ck_tile::stream_config& s) +template +void matrix_core_swizzle(matrix_core_swizzle_traits t, + matrix_core_swizzle_args a, + const ck_tile::stream_config& s) { - if(t.inst.compare("32x32x8") == 0) + if constexpr(!std::is_same_v) { - constexpr int BLOCK_SIZE = 256; - constexpr int NPerBlock = 256; - constexpr int KPerBlock 
= 128; - constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16; - if(t.permute.compare("0,1,4,2,5,3,6") == 0) + throw std::runtime_error("matrix_core_swizzle is only supported for fp16"); + } + else + { + if(t.inst.compare("32x32x8") == 0) { - constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; - using Kernel = - matrix_core_swizzle_kernel; + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; - auto k = Kernel(a); - float ave_time = ck_tile::launch_kernel(s, k); + auto k = Kernel(a); + ck_tile::launch_kernel(s, k); + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; - return ave_time; + auto k = Kernel(a); + ck_tile::launch_kernel(s, k); + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + ck_tile::launch_kernel(s, k); + } } - else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + else if(t.inst.compare("16x16x16") == 0) { - constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; - using Kernel = - matrix_core_swizzle_kernel; + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = 
+ matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; - auto k = Kernel(a); - float ave_time = ck_tile::launch_kernel(s, k); + auto k = Kernel(a); + ck_tile::launch_kernel(s, k); + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; - return ave_time; - } - else if(t.permute.compare("0,1,3,4,2,5") == 0) - { - constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::b_nr_kr_kw_nw_kv; - using Kernel = - matrix_core_swizzle_kernel; + auto k = Kernel(a); + ck_tile::launch_kernel(s, k); + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; - auto k = Kernel(a); - float ave_time = ck_tile::launch_kernel(s, k); - - return ave_time; + auto k = Kernel(a); + ck_tile::launch_kernel(s, k); + } } } - else if(t.inst.compare("16x16x16") == 0) - { - constexpr int BLOCK_SIZE = 256; - constexpr int NPerBlock = 256; - constexpr int KPerBlock = 128; - constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16; - if(t.permute.compare("0,1,4,2,5,3,6") == 0) - { - constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; - using Kernel = - matrix_core_swizzle_kernel; - - auto k = Kernel(a); - float ave_time = ck_tile::launch_kernel(s, k); - - return ave_time; - } - else if(t.permute.compare("0,1,2,4,5,3,6") == 0) - { - constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; - using Kernel = - matrix_core_swizzle_kernel; - - auto k = Kernel(a); - float ave_time = ck_tile::launch_kernel(s, k); - - return ave_time; - } - else if(t.permute.compare("0,1,3,4,2,5") == 0) - { - constexpr matrix_core_permute_style pstyle = 
- matrix_core_permute_style::b_nr_kr_kw_nw_kv; - using Kernel = - matrix_core_swizzle_kernel; - - auto k = Kernel(a); - float ave_time = ck_tile::launch_kernel(s, k); - - return ave_time; - } - } - - return -1; -} - -template <> -float matrix_core_swizzle(matrix_core_swizzle_traits, - matrix_core_swizzle_args, - const ck_tile::stream_config&) -{ - throw std::runtime_error("Not supported for fp8"); -} - -template <> -float matrix_core_swizzle(matrix_core_swizzle_traits, - matrix_core_swizzle_args, - const ck_tile::stream_config&) -{ - throw std::runtime_error("Not supported for fp32"); }