From 118afa455cd9a703125707e5a11668306907c013 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Wed, 14 Jan 2026 10:00:19 -0800 Subject: [PATCH 01/99] [CK_Tile] Support for group size 128 for Preshuffle quant for 2d block scale gemm (#3462) * formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * initial commit with prints * debugging * fine-grained working * debugging medium grained * fixing the tile window * formatting * enabling prefill shapes * working prefill shapes * formatted * clean up * code cleanup * bug fix after merging with develop * G128 working for both prefill and decode shapes for preshufflequant * clean up after merging with develop * fixing group 64 for decode shapes * non preshufflequant working for group size 128 * enable preshuffleb and preshufflequant with variour group sizes * reduce build time by splitting example into diff datatype files * Adding tests for preshuffleQuant * address review comment * fix for gfx1201 * compile time fix for gfx1201 * clang formatted --------- Co-authored-by: Cong Ma Co-authored-by: Thomas Ning Co-authored-by: Agarwal --- .../38_block_scale_gemm/CMakeLists.txt | 15 +- .../gemm_bquant_quantgrouped_bf8.cpp | 6 + 
.../gemm_bquant_quantgrouped_bf8i4.cpp | 6 + .../gemm_bquant_quantgrouped_fp8.cpp | 6 + .../gemm_bquant_quantgrouped_fp8i4.cpp | 6 + .../gemm_bquant_quantgrouped_preshuffleb.cpp | 222 -------------- ...mm_bquant_quantgrouped_preshuffleb_bf8.cpp | 53 ++++ ..._bquant_quantgrouped_preshuffleb_bf8i4.cpp | 57 ++++ ...mm_bquant_quantgrouped_preshuffleb_fp8.cpp | 53 ++++ ..._bquant_quantgrouped_preshuffleb_fp8i4.cpp | 57 ++++ ...antgrouped_preshuffleb_preshufflequant.cpp | 62 ---- ...rouped_preshuffleb_preshufflequant_bf8.cpp | 50 ++++ ...uped_preshuffleb_preshufflequant_bf8i4.cpp | 52 ++++ ...rouped_preshuffleb_preshufflequant_fp8.cpp | 50 ++++ ...uped_preshuffleb_preshufflequant_fp8i4.cpp | 52 ++++ ...mm_bquant_quantgrouped_preshufflequant.cpp | 270 ------------------ ...quant_quantgrouped_preshufflequant_bf8.cpp | 55 ++++ ...ant_quantgrouped_preshufflequant_bf8i4.cpp | 59 ++++ ...quant_quantgrouped_preshufflequant_fp8.cpp | 55 ++++ ...ant_quantgrouped_preshufflequant_fp8i4.cpp | 59 ++++ .../38_block_scale_gemm/gemm_quant.cpp | 39 ++- .../block_universal_gemm_as_bs_bquant_cr.hpp | 18 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 75 +++-- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 9 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 4 +- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 4 +- .../pipeline/gemm_group_quant_utils.hpp | 74 +++-- .../pipeline/gemm_quant_pipeline_problem.hpp | 2 - .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 43 +-- test/ck_tile/gemm_block_scale/CMakeLists.txt | 26 ++ ...quant_bquant_preshuffleQuant_decode_1d.cpp | 39 +++ ...quant_bquant_preshuffleQuant_decode_2d.cpp | 54 ++++ ...uant_bquant_preshuffleQuant_prefill_1d.cpp | 41 +++ ...uant_bquant_preshuffleQuant_prefill_2d.cpp | 63 ++++ ...gemm_quant_bquant_preshuffle_decode_2d.cpp | 13 +- ...emm_quant_bquant_preshuffle_prefill_2d.cpp | 15 +- .../test_gemm_quant_fixtures.hpp | 53 ++-- 37 files changed, 1136 insertions(+), 681 deletions(-) delete mode 100644 
example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp delete mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp delete mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp diff 
--git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 28e52b9275..ec536f7287 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -20,9 +20,18 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_bquant_quantgrouped_bf16mxfp4.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp - gemm_bquant_quantgrouped_preshuffleb.cpp - gemm_bquant_quantgrouped_preshufflequant.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp gemm_quant_rowcol.cpp gemm_quant_tensor.cpp ) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index 61fd65960f..82e30e56d2 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -49,4 +49,10 @@ void bquant_quantgrouped_bf8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return 
RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index 1d471068eb..515e6eb027 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -51,4 +51,10 @@ void bquant_quantgrouped_bf8i4_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index 280029033b..eaf10f057c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -49,4 +49,10 @@ void bquant_quantgrouped_fp8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index a277c864bb..c91867534f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -51,4 +51,10 @@ void bquant_quantgrouped_fp8i4_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; 
return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp deleted file mode 100644 index b32356c29d..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -#if CK_TILE_USE_WMMA -template -using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; -#else -template -using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; -#endif - -void bquant_quantgrouped_preshuffleb_instance_factory( - std::unordered_map>& lut) -{ - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using 
QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - 
QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - 
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; -} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp new file mode 100644 index 0000000000..7166a5647e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp new file mode 100644 index 0000000000..85599864db 
--- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git 
a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp new file mode 100644 index 0000000000..87cb4c9d10 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", 
"non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp new file mode 100644 index 0000000000..0cb16441a9 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + 
[](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp deleted file mode 100644 index 180f353df8..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -#if CK_TILE_USE_WMMA -template -using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; -#else -template -using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; -#endif - -void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( - std::unordered_map>& lut) -{ - using QuantGroupSize = ck_tile::QuantGroupShape>; - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] 
= - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; -} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp new file mode 100644 index 0000000000..640757a956 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp new file 
mode 100644 index 0000000000..575a43afd8 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return 
RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp new file mode 100644 index 0000000000..9e40fbaa87 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return 
RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp new file mode 100644 index 0000000000..2552a1d134 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = 
ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp deleted file mode 100644 index 62ca34b057..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -template -using GemmConfig = GemmConfigPreshuffleBQuantPrefill; - -void bquant_quantgrouped_preshufflequant_instance_factory( - std::unordered_map>& lut) -{ - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - 
"non-preshuffleb", - "preshufflequant", - "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { 
- using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = 
ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - 
ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; -} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp new file mode 100644 index 0000000000..edb28236af --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git 
a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp new file mode 100644 index 0000000000..59da63447e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + 
using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp new file mode 100644 index 0000000000..29c88001e8 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", 
"non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp new file mode 100644 index 0000000000..f487132557 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp 
b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 940c1b8cf3..8de58b0a30 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -111,11 +111,29 @@ void bquant_quantgrouped_bf8i4_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_bf16fp4_instance_factory( std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_instance_factory( +void bquant_quantgrouped_preshuffleb_fp8_instance_factory( std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_instance_factory( +void bquant_quantgrouped_preshuffleb_bf8_instance_factory( std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( +void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( std::unordered_map>& lut); void quant_rowcol_instance_factory( std::unordered_map>& lut); @@ -144,9 +162,18 @@ int main(int argc, char* argv[]) bquant_quantgrouped_fp8i4_instance_factory(lut); bquant_quantgrouped_bf8i4_instance_factory(lut); bquant_quantgrouped_bf16fp4_instance_factory(lut); - 
bquant_quantgrouped_preshuffleb_instance_factory(lut); - bquant_quantgrouped_preshufflequant_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); + bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut); quant_rowcol_instance_factory(lut); quant_tensor_instance_factory(lut); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 313e449c7b..03b9dfe34d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -319,7 +319,23 @@ struct BQuantBlockUniversalGemmAsBsCr if constexpr(PreshuffleQuant) { - constexpr index_t reg_offset = nIter; + // constexpr index_t reg_offset = nIter; + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::QuantGroupSize::kN > + (NWarp * WarpGemm::kN)) + { + if constexpr(Traits::NPerBlock == + GemmTraits::QuantGroupSize::kN) + return kQScale; + else + return nIter; // for prefill needs kQscale, for decode needs + // nIter + } + else + { + return nIter; + } + }(); auto pull_from_lane = (__lane_id() 
& (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 004fb18e0b..fd94dfb6b3 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -887,23 +887,27 @@ struct QuantGemmKernel if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - constexpr auto block_n = - TilePartitioner::NPerBlock / - QuantGroupSize::kN; // Number of N-dimension quantization groups per block - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at( - I1); // Number of N-dimension elements per warp - constexpr auto warp_per_group = - (QuantGroupSize::kN < - warp_n) // Determine how many warps share the same scale in N-dimension - ? (warp_n / QuantGroupSize::kN) - : (QuantGroupSize::kN / warp_n); - constexpr auto bqk_per_block = - TilePartitioner::KPerBlock / - QuantGroupSize::kK; // Number of K-dimension quantization groups per block - constexpr auto - tile_window_width = // The pre-shuffled layout flattens warp_n × - // bqk_per_block scales per row, Padded up to warp_size - // to ensure coalesced memory access. + + // Number of N-dimension quantization groups per block + constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock) + ? TilePartitioner::NPerBlock / QuantGroupSize::kN + : QuantGroupSize::kN / TilePartitioner::NPerBlock; + + // Number of N-dimension elements per warp + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); + + // Determine how many warps share the same scale in N-dimension + constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n) + ? 
(warp_n / QuantGroupSize::kN) + : (QuantGroupSize::kN / warp_n); + + // Number of K-dimension quantization groups per block + constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + + // The pre-shuffled layout flattens warp_n × + // bqk_per_block scales per row, Padded up to warp_size + // to ensure coalesced memory access. + constexpr auto tile_window_width = ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); // Adapts based on fine vs coarse quantization granularity: @@ -916,23 +920,42 @@ struct QuantGemmKernel // height = block_n constexpr auto tile_window_height = (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n; - auto block_n_idx = - i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a - // block index. - return make_tile_window( - bq_tensor_view, - make_tuple(number{}, number{}), - {block_n_idx * tile_window_height, 0}); + auto block_n_idx = i_n / TilePartitioner::NPerBlock; + + // For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ... + if(QuantGroupSize::kN > TilePartitioner::NPerBlock) + { + block_n_idx = block_n_idx >> 1; + } + + if(QuantGroupSize::kN > TilePartitioner::NPerBlock) + { + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, number{}), + {block_n_idx, 0}); + } + else + { + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, number{}), + {block_n_idx * tile_window_height, 0}); + } } else { + constexpr auto tensor_dim = + (QuantGroupSize::kN <= TilePartitioner::NPerBlock) + ? 
TilePartitioner::NPerBlock / QuantGroupSize::kN + : 1; if constexpr(std::is_same_v) { return make_tile_window( bq_tensor_view, make_tuple(number{}, - number{}), + number{}), {0, i_n / QuantGroupSize::kN}); } else @@ -940,7 +963,7 @@ struct QuantGemmKernel static_assert(std::is_same_v); return make_tile_window( bq_tensor_view, - make_tuple(number{}, + make_tuple(number{}, number{}), {i_n / QuantGroupSize::kN, 0}); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp index 7e4182e84f..271b35859e 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp @@ -26,14 +26,15 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase= 1, "NPerBlock must be >= QuantGroupSize"); + // static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize"); static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); - static_assert(NPerBlock % QuantGroupSize::kN == 0, - "NPerBlock must be a multiple of QuantGroupSize::kN"); + // static_assert(NPerBlock % QuantGroupSize::kN == 0, + // "NPerBlock must be a multiple of QuantGroupSize::kN"); static_assert(KPerBlock % QuantGroupSize::kK == 0, "KPerBlock must be a multiple of QuantGroupSize::kK"); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index a4bba6cf76..5c4dfd37c7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -45,7 +45,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = 
Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock) + ? NPerBlock / Problem::QuantGroupSize::kN + : 1; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 13d400d5fc..be91002cdb 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -66,7 +66,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3, @@ -275,15 +281,24 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding // Example: NPerQ=32, WarpGemm::kN=16, NWarps=4 // → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups) - constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale - constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales) - constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group - constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups - constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN) - constexpr auto N2 = 1; // Elements per thread - constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ) + // KR: Number of warps sharing the same scale + // K1: Number of distinct warp groups (unique scales) + // K0: Iterations to cover K-tile per warp group + // N1: K-dimension quantization groups + // N0: Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN) + // N2: Elements per thread + // NR1: Scale broadcast factor (full NPerQ) + // 
NR0: Remaining interleave factor + + constexpr auto KR = NPerQ / WarpGemm::kN; + constexpr auto K1 = NWarps / KR; + constexpr auto K0 = KPerTile / K1; + constexpr auto N1 = BlockGemmShape::kK / KPerQ; + constexpr auto N0 = 1; + constexpr auto N2 = 1; + constexpr auto NR1 = NPerQ; constexpr auto NR0 = - warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + (warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1); return make_static_tile_distribution( tile_distribution_encoding, @@ -303,12 +318,19 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding // // Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 // → 128 >= 16*4=64, so all 4 warps use the same scale - constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups - constexpr auto N0 = 1; // Minimal (1) since scale is shared across N - constexpr auto N2 = 1; // Elements per thread - constexpr auto NR1 = 32; // Fixed broadcast size + + // N1: K-dimension quantization groups + // N0: Minimal (1) since scale is shared across N + // N2: Elements per thread + // NR1: Fixed broadcast size + // NR0: Remaining interleave factor + + constexpr auto N1 = BlockGemmShape::kK / KPerQ; + constexpr auto N0 = 1; + constexpr auto N2 = 1; + constexpr auto NR1 = 32; constexpr auto NR0 = - warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + (warp_size <= (N0 * N1 * N2 * NR1)) ? 
1 : warp_size / (N0 * N1 * N2 * NR1); return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index c8acb785cf..39b00d2501 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -79,10 +79,8 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase 0 && i_inst % ds_rep == 0) + if constexpr(ds_rep > 0) { - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::DS_READ, 1, 0); // DS read - } - if constexpr(ds_rep > 0 && i_inst % ds_rep == 1) - { - __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write - } - - if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0) - { - if constexpr(ds_write_inst > 0) + if(i_inst % ds_rep == 0) { __builtin_amdgcn_sched_group_barrier( - LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read + } + } + if constexpr(ds_rep > 0) + { + if(i_inst % ds_rep == 1) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write + } + } + + if constexpr(buffer_load_rep > 0) + { + if(i_inst % buffer_load_rep == 0) + { + if constexpr(ds_write_inst > 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read + } } } // Always mark some VALU work in the loop to reflect auxiliary scalar @@ -354,7 +363,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) ? 
ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) : ck_tile::integer_least_multiple(n, kNPerBlock) / BlockGemmShape::WarpTile::at(number<1>{})), @@ -431,7 +440,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) : ck_tile::integer_least_multiple(n, kNPerBlock) / BlockGemmShape::WarpTile::at(number<1>{})), @@ -468,7 +477,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) : ck_tile::integer_least_multiple(n, kNPerBlock) / BlockGemmShape::WarpTile::at(number<1>{})), diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 2dad8be205..5749a8d3b2 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -117,6 +117,27 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant tests (with PreshuffleQuant) - split into 4 files + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_decode_1d + test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_decode_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_1d + test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_1d 
PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_decode_2d + test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_decode_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_2d + test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # RowColQuant tests add_gtest_executable(test_tile_gemm_quant_rowcol test_gemm_quant_rowcol.cpp @@ -152,6 +173,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") test_tile_gemm_quant_bquant_preshuffle_tiled_permute test_tile_gemm_quant_bquant_preshuffle_decode_2d test_tile_gemm_quant_bquant_preshuffle_prefill_2d + # BQuant preshuffleQuant tests + test_tile_gemm_quant_bquant_preshuffleQuant_decode_1d + test_tile_gemm_quant_bquant_preshuffleQuant_prefill_1d + test_tile_gemm_quant_bquant_preshuffleQuant_decode_2d + test_tile_gemm_quant_bquant_preshuffleQuant_prefill_2d # Other quant tests test_tile_gemm_quant_rowcol test_tile_gemm_quant_tensor diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp new file mode 100644 index 0000000000..661fd5bd33 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Decode Config 1D +// Tuple format: +// clang-format off +using BPreshuffleDecode1DTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Decode 1D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode1DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp new file mode 100644 index 0000000000..fb4020bcd7 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp @@ -0,0 +1,54 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Decode 2D +// Tuple format: +// clang-format off +using BPreshuffleDecode2DTypes = ::testing::Types< + // 2d cases with preshuffle B + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Decode 2D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode2DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp new file mode 100644 index 0000000000..0d4e4d5f03 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Prefill Config 1D +// Tuple format: +// clang-format off +using BPreshufflePrefill1DTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Prefill 1D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill1DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp new file mode 100644 index 0000000000..edc7bcaa09 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp @@ -0,0 +1,63 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Prefill 2D +// Tuple format: +// clang-format off +using BPreshufflePrefill2DTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Prefill 2D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill2DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp index 65ea165b10..66fb62e67e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp @@ -19,10 +19,11 @@ using PkInt4 = ck_tile::pk_int4_t; using BQuantGrouped = std::integral_constant; 
// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; // Type combinations for BQuant Preshuffle tests - Decode 2D // Tuple format: , std::tuple, std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp index 368204987a..ace07a37ae 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp @@ -19,10 +19,11 @@ using PkInt4 = ck_tile::pk_int4_t; using BQuantGrouped = std::integral_constant; // 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; // Type combinations for BQuant Preshuffle tests - Prefill 2D // Tuple format: , std::tuple, std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 
3798cc4443..79c86935ef 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -53,11 +53,20 @@ struct GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +struct GemmConfigDecode : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +}; + struct GemmConfigPrefill : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigMxFp4 : public GemmConfigBase @@ -89,42 +98,26 @@ struct GemmConfigPadding : public GemmConfigBase static constexpr bool kPadK = true; }; -struct GemmConfigPreshuffleBDecode : public GemmConfigBase +struct GemmConfigPreshuffleBDecode : public GemmConfigDecode { static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; - - // Default GEMM tile sizes for tests - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; -struct GemmConfigPreshuffleBPrefill : public GemmConfigBase +struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode +{ + static 
constexpr bool PreshuffleQuant = true; +}; + +struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill { static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; +}; - // Default GEMM tile sizes for tests - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill +{ + static constexpr bool PreshuffleQuant = true; }; struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill From a07c8e38bd5152f2582dd0c8c1f8eef72f1086e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 14 Jan 2026 20:04:37 +0100 Subject: [PATCH 02/99] Fix grouped conv bwd data wmma check (#3562) --- ...e_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index d33e807828..b324845c3e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1698,6 +1698,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } + else + { + valid = false; + } } else { @@ -1716,6 +1720,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } + else + { + valid = false; 
+ } } if(!valid) { From a346cfa9607b6b334f99c8e32318cb29b81203dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 14 Jan 2026 21:37:12 +0100 Subject: [PATCH 03/99] Disable ActiveWorkgroupsPerCU for different arch in wmma kernels (#3566) --- .../impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 4 ++++ ...ice_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 4 ++++ .../impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index 2a1a210398..126d107725 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -314,6 +314,10 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 { ActiveWorkgroupsPerCU() { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } constexpr int dynamic_smem_size = 0; int max_occupancy = 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 843705692b..f9b2ff0596 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -466,6 +466,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { ActiveWorkgroupsPerCU() { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = BlkGemmPipeSched == 
BlockGemmPipelineScheduler::Intrawave ? 1 : 2; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index c070d8d9e9..3f8093afe1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -415,6 +415,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { ActiveWorkgroupsPerCU() { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; From f08fb3f748ca693f0932d2552f30684b8a81f8f0 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 14 Jan 2026 12:43:55 -0800 Subject: [PATCH 04/99] [CK_BUILDER] Update owners file for more reviews for CK Builder (#3572) Adding owners permissions for two leading developers on the CK Builder subproject to help with reviews on that project, especially in the EU time zones. 
Remove aska-0096, who has left AMD --- .github/CODEOWNERS | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index af36f492ba..0d7bcd6b18 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD 
@vpietila-amd @Snektron +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron From 7f912909ca2c3cedfa1c6397d75daba4903a6d0d Mon Sep 17 00:00:00 2001 From: Emily Martins <65371150+ecamartins@users.noreply.github.com> Date: Wed, 14 Jan 2026 14:02:21 -0700 Subject: [PATCH 05/99] Disable CK Tile Stream-K reduction tests (#3559) The test_ck_tile_streamk_reduction test suite seems to have transient failures; hence, we are disabling these tests for now. We will re-enable them once the bug is resolved. 
--- test/ck_tile/gemm_streamk/CMakeLists.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 1390e5ee07..6aaa145c7d 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,9 +23,10 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) - add_gtest_executable(test_ck_tile_streamk_reduction - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp - test_gemm_streamk_util.cpp) + # TODO: Re-enable once transient bug for reduction is resolved. + # add_gtest_executable(test_ck_tile_streamk_reduction + # ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp + # test_gemm_streamk_util.cpp) add_gtest_executable(test_ck_tile_streamk_smoke ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp From 8705fdcb0c738907fea74b7ed39c9f73fb9a5892 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 14 Jan 2026 14:07:47 -0800 Subject: [PATCH 06/99] add aiter test_batch_prefill and simplify jenkins file a bit (#3570) --- Jenkinsfile | 40 ++++++---------------------------------- 1 file changed, 6 insertions(+), 34 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 9c670183fd..e01cfcbf01 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -811,41 +811,12 @@ def Build_CK(Map conf=[:]){ archiveArtifacts "perf_*.log" stash includes: "perf_**.log", name: "perf_log_${arch}" } - // disable performance tests on gfx1030 for now.
- //else if ( arch == "gfx10"){ - // run basic tests on gfx1030 - // echo "Run gemm performance tests" - // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" - // archiveArtifacts "perf_onnx_gemm_gfx10.log" - // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" - //} - else if ( arch == "gfx11"){ - // run basic tests on gfx11 + else if ( arch != "gfx10"){ + // run basic tests on gfx11/gfx12/gfx908/gfx950, but not on gfx10, it takes too long echo "Run gemm performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" - archiveArtifacts "perf_onnx_gemm_gfx11.log" - stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" - } - else if ( arch == "gfx120" ){ - // run basic tests on gfx12 - echo "Run gemm performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" - archiveArtifacts "perf_onnx_gemm_gfx12.log" - stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" - } - else if ( arch == "gfx908" ){ - // run basic tests on gfx908 - echo "Run performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" - archiveArtifacts "perf_onnx_gemm_gfx908.log" - stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" - } - else if ( arch == "gfx950" ){ - // run basic tests on gfx950 - echo "Run performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950" - archiveArtifacts "perf_onnx_gemm_gfx950.log" - stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} ${arch}" + archiveArtifacts "perf_onnx_gemm_*.log" + stash includes: "perf_onnx_gemm_**.log", name: "perf_log_${arch}" } } } @@ -1049,6 
+1020,7 @@ def run_aiter_tests(Map conf=[:]){ sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_batch_prefill.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" From df7ee270a6bbe5d8562d954a919c7299512dad73 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 14 Jan 2026 16:41:34 -0500 Subject: [PATCH 07/99] Update README.md files to match recent code changes This is mostly adjustments to enum values so that the docs align correctly with the current code. Also updated the calendar scope of the project to extend through March 2026. --- experimental/builder/README.md | 4 ++-- experimental/builder/include/ck_tile/builder/README.md | 10 ++++++++-- .../builder/include/ck_tile/builder/testing/README.md | 8 ++++---- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/experimental/builder/README.md b/experimental/builder/README.md index 1156de0e9c..850bcf136e 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -2,13 +2,13 @@ This directory contains the experimental builder feature for composable_kernel. -* Status: In development (October - December 2025) +* Status: In development (October 2025 - March 2026) ## Overview The builder provides a high-level, semantically-clear interface for constructing composable kernel operations, with an initial focus on convolution kernels for MIOpen. It leverages modern C++20 features (such as POD structs as non-type template parameters, concepts, and designated initializers) to simplify kernel instantiation and improve developer experience. 
-This project is a prototype for a more general builder pattern for all of composable_kernel (CK) and CKTile, but is currently limited to formalizing the interface between MIOpen and CK. +This project is a prototype for a more general builder pattern for all of composable_kernel (CK) and CK Tile, but is currently limited to formalizing the interface between MIOpen and CK. ## Design descriptions diff --git a/experimental/builder/include/ck_tile/builder/README.md b/experimental/builder/include/ck_tile/builder/README.md index af8c4ec01b..0af0cede60 100644 --- a/experimental/builder/include/ck_tile/builder/README.md +++ b/experimental/builder/include/ck_tile/builder/README.md @@ -100,8 +100,8 @@ concept ConvSignatureDescriptor = requires(T t) { - `FORWARD`: Standard forward convolution - `BACKWARD_DATA`: Gradient computation w.r.t. input - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights -- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors) -- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors) +- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE which indicates the type should be inferred or specified per-tensor, may be overridden by individual tensors) +- **`elementwise_operation`**: Default elementwise operation for all tensors (Optional, defaults to PASS_THROUGH, may be overridden by individual tensors via their `operation` field) - **`accumulation_data_type`**: Type used for internal accumulation #### 2. 
Tensor Level @@ -133,6 +133,9 @@ concept TensorConfigDescriptor = requires(T t) { ``` **Layout Types** (dimension-specific): +- **Special Values**: + - `UNDEFINED_TENSOR_LAYOUT`: Placeholder value indicating layout is not yet specified or should be inferred + - **1D Convolution**: - Input: `GNCW`, `GNWC`, `NWGC`, `NGCW`, `G_NW_C_strided` - Weight: `GKXC`, `GKCX`, `KXGC`, `G_K_X_C_strided` @@ -148,6 +151,9 @@ concept TensorConfigDescriptor = requires(T t) { - Weight: `GKZYXC`, `GKCZYX`, `KZYXGC`, `G_K_ZYX_C_strided` - Output: `GNKDHW`, `GNDHWK`, `NDHWGK`, `NGKDHW`, `G_NDHW_K_strided` +- **Bias Tensors**: + - `GC`, `G_C_strided`, `G_K_strided` + Where: - `G` = Groups - `N` = Batch size diff --git a/experimental/builder/include/ck_tile/builder/testing/README.md b/experimental/builder/include/ck_tile/builder/testing/README.md index 85adc59d80..c6662c2b04 100644 --- a/experimental/builder/include/ck_tile/builder/testing/README.md +++ b/experimental/builder/include/ck_tile/builder/testing/README.md @@ -53,7 +53,7 @@ struct ConvSignature { ck_tile::builder::DataType data_type = ck_tile::builder::DataType::FP16; ck_tile::builder::ElementwiseOperation elementwise_operation = - ck_tile::builder::ElementwiseOperation::NONE; + ck_tile::builder::ElementwiseOperation::PASS_THROUGH; }; // Double-check that out structure is well-defined according to the CK-Builder API. 
@@ -66,7 +66,7 @@ constexpr auto SIGNATURE = ConvSignature{ .direction = ck_tile::builder::ConvDirection::FORWARD, .layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = ck_tile::builder::DataType::FP16, - .elementwise_operation = ck_tile::builder::ElementwiseOperation::NONE, + .elementwise_operation = ck_tile::builder::ElementwiseOperation::PASS_THROUGH, }; ``` @@ -243,7 +243,7 @@ struct ConvSignature { ck_tile::builder::DataType data_type = ck_tile::builder::DataType::FP16; ck_tile::builder::ElementwiseOperation elementwise_operation = - ck_tile::builder::ElementwiseOperation::NONE; + ck_tile::builder::ElementwiseOperation::PASS_THROUGH; }; static_assert(ck_tile::builder::ConvSignatureDescriptor); constexpr auto SIGNATURE = ConvSignature{ @@ -251,7 +251,7 @@ constexpr auto SIGNATURE = ConvSignature{ .direction = ck_tile::builder::ConvDirection::FORWARD, .layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = ck_tile::builder::DataType::FP16, - .elementwise_operation = ck_tile::builder::ElementwiseOperation::NONE, + .elementwise_operation = ck_tile::builder::ElementwiseOperation::PASS_THROUGH, }; // Define the convolution algorithm From 51226372156901aa20a34ed5146d6bd57c63e519 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Thu, 15 Jan 2026 01:03:21 -0800 Subject: [PATCH 08/99] [CK_BUILDER] Convert convolution traits to a struct with factory functions (#3547) * Factor helpers out of conv_traits.hpp * Create a non-templated conv_traits struct * Migrate to new instance-specific instance_to_conv_traits functions * Clean up reflection concepts * Clean up ConvTraits helpers * Update testing for convolution traits This is a lot of cleanup on tests to have verbose coverage of feature extraction, explicit tests for each supported device kernel, and simple, readable test code. 
* Address reviewer comments and resolve merge conflict --- .../ck_tile/builder/reflect/conv_describe.hpp | 61 +- .../ck_tile/builder/reflect/conv_traits.hpp | 727 ++--------- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 84 ++ ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 84 ++ ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 84 ++ .../builder/reflect/conv_traits_helpers.hpp | 739 +++++++++++ .../reflect/instance_to_conv_traits.hpp | 8 + ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 8 + ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 8 + ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 8 + experimental/builder/test/CMakeLists.txt | 3 +- .../builder/test/conv/ck/test_conv_traits.cpp | 156 +-- .../conv/ck/unit_instance_to_conv_traits.cpp | 1127 ----------------- .../unit_instance_to_conv_traits_features.cpp | 800 ++++++++++++ ...unit_instance_to_conv_traits_instances.cpp | 262 ++++ ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 2 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 2 +- 17 files changed, 2288 insertions(+), 1875 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp delete mode 100644 experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp create mode 100644 experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp create mode 100644 experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp index fdbfa7c4e1..359b12c4a3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp @@ -7,43 +7,52 @@ #pragma once #include "ck_tile/builder/reflect/conv_description.hpp" -#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/instance_to_conv_traits.hpp" namespace ck_tile::reflect { -/// @brief Factory function to create ConvDescription from a convolution instance type -/// @tparam Instance The convolution instance type (must have ConvTraits) -/// @return A ConvDescription object populated with the instance's configuration details -template +/// @brief Concept to check if an Instance type has conv traits +template +concept HasConvTraits = requires { + { conv::instance_to_conv_traits() }; +}; + +/// Factory function to create ConvDescription from a convolution instance type +/// Instance The convolution instance type +/// A ConvDescription object populated with the instance's configuration details +/// +/// TODO: Fix ConvDescription to just use the ConvTraits directly. 
+template + requires HasConvTraits conv::ConvDescription describe() { - using Traits = conv::ConvTraits; + const auto traits = conv::instance_to_conv_traits(); return conv::ConvDescription( conv::ConvSignatureInfo{ - .spatial_dim = Traits::spatial_dim, - .direction = Traits::direction, - .input_layout = Traits::layout[0], - .weight_layout = Traits::layout[1], - .output_layout = Traits::layout[2], - .data_type = Traits::data_type, - .input_element_op = Traits::input_element_op, - .weight_element_op = Traits::weight_element_op, - .output_element_op = Traits::output_element_op, + .spatial_dim = traits.spatial_dim, + .direction = traits.direction, + .input_layout = traits.layout[0], + .weight_layout = traits.layout[1], + .output_layout = traits.layout[2], + .data_type = traits.data_type, + .input_element_op = traits.input_element_op, + .weight_element_op = traits.weight_element_op, + .output_element_op = traits.output_element_op, }, conv::GemmAlgorithmInfo{ - .thread_block_size = Traits::thread_block_size, - .tile_dims = Traits::tile_dims, - .warp_gemm = Traits::warp_gemm, - .a_tile_transfer = Traits::a_tile_transfer, - .b_tile_transfer = Traits::b_tile_transfer, - .c_tile_transfer = Traits::c_tile_transfer, - .pipeline_version = Traits::pipeline_version, - .pipeline_scheduler = Traits::pipeline_scheduler, - .conv_specialization = Traits::conv_specialization, - .padding = Traits::gemm_padding, + .thread_block_size = traits.thread_block_size, + .tile_dims = traits.tile_dims, + .warp_gemm = traits.warp_gemm, + .a_tile_transfer = traits.a_tile_transfer, + .b_tile_transfer = traits.b_tile_transfer, + .c_tile_transfer = traits.c_tile_transfer, + .pipeline_version = traits.pipeline_version, + .pipeline_scheduler = traits.pipeline_scheduler, + .conv_specialization = traits.conv_specialization, + .padding = traits.gemm_padding, }, - []() { return reflect::instance_string(); }); + []() { return reflect::instance_string(); }); } } // namespace ck_tile::reflect diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 8caa11618e..451a74be34 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -1,664 +1,109 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +// Runtime-accessible convolution kernel configuration data structure +// +// This file defines ConvTraits, a pure data structure that captures the complete +// configuration of a convolution kernel in a domain-specific abstraction, without +// requiring knowledge of the underlying kernel instance implementation details. +// +// ## Purpose and Design +// +// ConvTraits provides type erasure for convolution kernel configurations, allowing +// for reflection of convolution kernel objects. The struct represents kernel +// traits in terms of convolution-specific concepts for AMD GPUs rather than raw +// template parameters. +// +// ## Architecture and Usage +// +// ConvTraits sits at the center of the reflection system: +// +// 1. **Population**: Values are created by `instance_to_conv_traits()` template +// specializations that extract configuration from compile-time InstanceTraits +// +// 2. 
**Consumption**: Used by ConvDescription to provide human-readable descriptions +// of kernel configurations for debugging, logging, and documentation +// +// ## Structure Organization +// +// The struct separates kernel configuration into two logical categories: +// +// - **Signature Information**: Defines what the kernel computes (direction, layouts, +// data types, elementwise operations, specializations) +// +// - **Algorithm Information**: Defines how the kernel computes (thread block size, +// tile dimensions, memory access patterns, pipeline configuration) +// +// ## Evolution and Extensibility +// +// ConvTraits is designed to evolve through composition (not inheritance): +// +// - Currently supports XDL forward convolution kernels +// - Will extend to the other forward convolutions +// - Will be extended to cover backward data and backward weight convolutions +// - Will incorporate fusion operations and additional specializations +// - Uses std::optional and std::variant for optional/variant fields +// - Eventually will generalize to KernelTraits for GEMM, flash attention, etc. 
+ #pragma once -#include -#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" -#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/utility/pipeline_enum.hpp" -#include "ck/utility/scheduler_enum.hpp" -#include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/reflect/conv_types.hpp" -#include "ck_tile/builder/reflect/instance_traits.hpp" -#include "ck_tile/builder/reflect/instance_traits_util.hpp" #include "ck_tile/builder/types.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/grouped_convolution.hpp" namespace ck_tile::reflect::conv { -// Forward convolution layout concept - checks for A/B/E layout types -template -concept HasFwdConvLayouts = requires { - typename T::ALayout; - typename T::BLayout; - typename T::ELayout; -}; - -// GEMM specialization concept - checks for kGemmSpecialization member -template -concept HasGemmSpec = requires { - { - T::kGemmSpecialization - } -> std::convertible_to; -}; - -// Data types concept - checks for ADataType member -template -concept HasDataTypes = requires { typename T::ADataType; }; - -// Elementwise operations concept - checks for A/B/CDE elementwise operation types -template -concept HasElementwiseOps = requires { - typename T::AElementwiseOperation; - typename T::BElementwiseOperation; - typename T::CDEElementwiseOperation; -}; - -// Tile parameters concept - checks for tile dimension and transfer members -template -concept HasTileParams = requires { - { T::kKPerBlock } -> std::convertible_to; - { T::kMPerBlock } -> std::convertible_to; - { T::kNPerBlock } -> std::convertible_to; - { T::kAK1 } -> std::convertible_to; - { T::kBK1 } -> std::convertible_to; - T::kCThreadClusterLengths; 
-}; - -// Comprehensive concept that checks if an instance has all XDL forward convolution traits -// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions -template -concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && - HasElementwiseOps && HasTileParams; - -// Primary concept for checking if a type can be described -// Currently only forward convolutions are supported, but this can be extended -// in the future to include backward data and backward weight convolutions -template -concept HasConvTraits = IsXdlFwdConv>; - -// Helper metafunctions to convert from ck enums to builder enums - -/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. -/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. -/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). -/// @details This function maps CK's block GEMM pipeline version identifiers to the -/// builder framework's standardized pipeline version enum. The pipeline version -/// determines the strategy used for data movement and computation overlap in the -/// GEMM kernel's main loop. -template -constexpr auto convert_pipeline_version() +// Runtime data structure representing a convolution kernel's complete configuration +// +// This pure data struct (no template parameters, no static members) provides +// type erasure for convolution kernel configurations. It can hold the configuration +// from any convolution kernel instance, enabling runtime storage, comparison, and +// manipulation of kernel properties. +// +// The struct is populated by `instance_to_conv_traits()` template specializations +// that extract compile-time configuration from InstanceTraits and convert it to +// this standardized runtime representation. 
+// +// Members are organized into two categories: +// - **Signature Information**: Defines the computational interface (what to compute) +// - **Algorithm Information**: Defines the implementation strategy (how to compute) +// +// Note: This struct will evolve to support additional convolution variants and +// eventually generalize to other kernel types through composition. +// +// There is a lot we still need to do: +// +// TODO: Generalize type support for all tensors and accumulator. +// TODO: Describe all tensros. +// TODO: Include the full generalization of the signature from the input schema. +// TODO: Include the full generalization of the algorithm from the input schema. +struct ConvTraits { - using enum ck::BlockGemmPipelineVersion; - using enum builder::PipelineVersion; - - switch(ck_ver) - { - case v1: return V1; - case v2: return V2; - case v3: return V3; - case v4: return V4; - case v5: return V5; - } -} - -/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. -/// @tparam ck_ver The CK PipelineVersion enum value to convert. -/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). -/// @details This function maps CK's general pipeline version identifiers to the -/// builder framework's standardized pipeline version enum. Note that this overload -/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion -/// variant, including support for specialized weight-only pipelines. -template -constexpr auto convert_pipeline_version() -{ - using enum ck::PipelineVersion; - using enum builder::PipelineVersion; - - switch(ck_ver) - { - case v1: return V1; - case v2: return V2; - case v4: return V4; - case weight_only: return WEIGHT_ONLY; - } -} - -/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. -/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. 
-/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). -/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the -/// builder framework's standardized scheduler enum. The scheduler determines how work -/// is distributed and synchronized within and across wavefronts during pipeline execution. -/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates -/// across multiple wavefronts. -template -constexpr auto convert_pipeline_scheduler() -{ - using enum ck::BlockGemmPipelineScheduler; - using enum builder::PipelineScheduler; - - switch(ck_sched) - { - case Intrawave: return INTRAWAVE; - case Interwave: return INTERWAVE; - } -} - -/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. -/// @tparam ck_sched The CK LoopScheduler enum value to convert. -/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). -/// @details This function maps CK's loop scheduler identifiers to the builder framework's -/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of -/// the main computational loop are scheduled across threads. DEFAULT uses the standard -/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved -/// performance in certain scenarios. -template -constexpr auto convert_pipeline_scheduler() -{ - using enum ck::LoopScheduler; - using enum builder::PipelineScheduler; - - switch(ck_sched) - { - case Default: return DEFAULT; - case Interwave: return INTERWAVE; - } -} - -// Helper metafunctions to derive signature information from Instance types - -/// @brief Helper function to report unsupported convolution direction with a clear error message. 
-template -[[noreturn]] consteval void report_unsupported_conv_direction_error() -{ - throw "Unsupported convolution direction detected!\n" - "The kernel instance does not have a recognized convolution specialization.\n" - "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " - "kConvBwdWeightSpecialization.\n" - "Please verify that your kernel instance is properly configured."; -} - -/// @brief Derives the convolution direction from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). -template -constexpr builder::ConvDirection conv_direction() -{ - using InstTraits = InstanceTraits; - - if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) - return builder::ConvDirection::FORWARD; - else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) - return builder::ConvDirection::BACKWARD_DATA; - else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) - return builder::ConvDirection::BACKWARD_WEIGHT; - else - { - report_unsupported_conv_direction_error(); - return builder::ConvDirection::FORWARD; // Unreachable - } -} - -/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvSpecialization` enum value. 
-template -constexpr auto conv_spec() -{ - using InstTraits = InstanceTraits; - using enum builder::ConvSpecialization; - - if constexpr(requires { InstTraits::kConvForwardSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; - switch(InstTraits::kConvForwardSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Pad0: return FILTER_1X1_PAD0; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - case Filter3x3: return FILTER_3x3; - case OddC: return ODD_C; - } - } - else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; - switch(InstTraits::kConvBwdDataSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - } - } - else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; - switch(InstTraits::kConvBwdWeightSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - case Filter1x1Pad0: return FILTER_1X1_PAD0; - case OddC: return ODD_C; - } - } -} - -// Helper variable template to check if CK layout enums match -template -inline constexpr bool layouts_are = - std::is_same_v && std::is_same_v && std::is_same_v; - -/// @brief Helper function to report unsupported layout combinations with a clear error message. -/// @details This consteval function is designed to fail at compile time with a descriptive -/// error message when an unsupported layout combination is encountered. 
-template -[[noreturn]] consteval void report_unsupported_layout_error() -{ - // This will produce a compile-time error with the exception message - throw "Unsupported convolution layout combination detected!\n" - "The combination of ALayout, BLayout, and ELayout template parameters\n" - "is not recognized for the given spatial dimension.\n" - "Please verify that your convolution instance uses a supported layout configuration.\n" - "Check the conv_layout() function for the list of supported layout combinations."; -} - -/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return An std::array corresponding to the tensor layouts: -/// index 0 -> Input layout -/// index 1 -> Weight layout -/// index 2 -> Output layout -template -constexpr auto conv_layout() - requires HasFwdConvLayouts> -{ - // Helper lambda to construct layout array - auto layouts = [](auto... Ls) { return std::array{Ls...}; }; - - using A = typename InstanceTraits::ALayout; - using B = typename InstanceTraits::BLayout; - using E = typename InstanceTraits::ELayout; - namespace ctl = ck::tensor_layout::convolution; - using enum builder::TensorLayout; - - switch(InstanceTraits::kSpatialDim) - { - case 1: - if constexpr(layouts_are) - return layouts(GNWC, GKXC, GNWK); - if constexpr(layouts_are) - return layouts(GNWC, GKXC, GNWK); - if constexpr(layouts_are) - return layouts(NWGC, GKXC, NWGK); - if constexpr(layouts_are) - return layouts(NGCW, GKXC, NGKW); - if constexpr(layouts_are) - return layouts(NGCW, GKCX, NGKW); - break; - case 2: - if constexpr(layouts_are) - return layouts(GNHWC, GKYXC, GNHWK); - if constexpr(layouts_are) - return layouts(GNHWC, GKYXC, GNHWK); - if constexpr(layouts_are) - return layouts(NHWGC, GKYXC, NHWGK); - if constexpr(layouts_are) - return layouts(NHWGC, GKYXC, NHWGK); - if constexpr(layouts_are) - return layouts(NGCHW, GKYXC, NGKHW); - if constexpr(layouts_are) - return 
layouts(NGCHW, GKCYX, NGKHW); - break; - case 3: - if constexpr(layouts_are) - return layouts(GNDHWC, GKZYXC, GNDHWK); - if constexpr(layouts_are) - return layouts(GNDHWC, GKZYXC, GNDHWK); - if constexpr(layouts_are) - return layouts(NDHWGC, GKZYXC, NDHWGK); - if constexpr(layouts_are) - return layouts(NGCDHW, GKZYXC, NGKDHW); - if constexpr(layouts_are) - return layouts(NGCDHW, GKCZYX, NGKDHW); - break; - } - - // If we reach here, the layout combination is not supported - // Call consteval function to trigger a compile-time error with a clear message - report_unsupported_layout_error::kSpatialDim>(); - - // This return is unreachable but needed to satisfy the compiler - return layouts(GNHWC, GKYXC, GNHWK); -} - -/// @brief Helper function to report unsupported data type with a clear error message. -template -[[noreturn]] consteval void report_unsupported_data_type_error() -{ - throw "Unsupported data type detected!\n" - "The ADataType is not recognized.\n" - "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " - "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " - "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " - "(BF8), " - "int8_t (I8), ck::Tuple (I8_I8), uint8_t (U8).\n" - "Please verify that your kernel instance uses a supported data type."; -} - -/// @brief Derives the data type from a device kernel `Instance` type. -/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). 
-template -constexpr builder::DataType conv_data_type() - requires HasDataTypes> -{ - using InstTraits = InstanceTraits; - using ADataType = typename InstTraits::ADataType; - using enum builder::DataType; - - if constexpr(std::is_same_v) - return FP16; - else if constexpr(std::is_same_v>) - return FP16_FP16; - else if constexpr(std::is_same_v) - return BF16; - else if constexpr(std::is_same_v>) - return BF16_BF16; - else if constexpr(std::is_same_v) - return FP32; - else if constexpr(std::is_same_v>) - return FP32_FP32; - else if constexpr(std::is_same_v) - return FP64; - else if constexpr(std::is_same_v) - return FP8; - else if constexpr(std::is_same_v) - return BF8; - else if constexpr(std::is_same_v) - return BF8; - else if constexpr(std::is_same_v) - return I8; - else if constexpr(std::is_same_v>) - return I8_I8; - else if constexpr(std::is_same_v) - return U8; - else - { - report_unsupported_data_type_error(); - return FP32; // Unreachable - } -} - -/// @brief Helper function to report unsupported elementwise operation with a clear error message. -template -[[noreturn]] consteval void report_unsupported_elementwise_op_error() -{ - throw "Unsupported elementwise operation detected!\n" - "The elementwise operation type is not recognized.\n" - "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, " - "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, " - "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, " - "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, " - "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, " - "UnaryConvert.\n" - "Please verify that your kernel instance uses a supported elementwise operation."; -} - -/// @brief Derives the elementwise operation from op type. -/// @tparam ElementwiseOp Elementwise operation functor type. 
-/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. -template -constexpr builder::ElementwiseOperation elementwise_op() -{ - using enum builder::ElementwiseOperation; - constexpr std::string_view name = detail::elementwise_op_name(); - - if constexpr(detail::case_insensitive_equal(name, "AddClamp")) - return ADD_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd")) - return ADD_RELU_ADD; - else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) - return BIAS_BNORM_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) - return BILINEAR; - else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp")) - return BIAS_BNORM_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Clamp")) - return CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale")) - return CONV_INVSCALE; - else if constexpr(detail::case_insensitive_equal(name, "ConvScale")) - return CONV_SCALE; - else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd")) - return CONV_SCALE_ADD; - else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu")) - return CONV_SCALE_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Scale")) - return SCALE; - else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd")) - return SCALE_ADD; - else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) - return PASS_THROUGH; - else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) - return SCALEADD_SCALEADD_RELU; - else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp")) - return DYNAMIC_UNARY_OP; - else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp")) - return UNARY_COMBINED_OP; - else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp")) - return ACTIVATION_MUL2_CLAMP; - else if 
constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp")) - return ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp")) - return ADD_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp")) - return ADD_ACTIVATION_MUL2_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp")) - return ADD_MUL_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp")) - return ADD_MUL2_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert")) - return UNARY_CONVERT; - else if constexpr(detail::case_insensitive_equal(name, "Logistic")) - return LOGISTIC; - else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu")) - return CLIPPED_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Swish")) - return SWISH; - else if constexpr(detail::case_insensitive_equal(name, "Elu")) - return ELU; - else if constexpr(detail::case_insensitive_equal(name, "Power")) - return POWER; - else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu")) - return LEAKY_RELU; - else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs")) - return UNARY_ABS; - else if constexpr(detail::case_insensitive_equal(name, "Relu")) - return RELU; - else if constexpr(detail::case_insensitive_equal(name, "SoftRelu")) - return SOFT_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Sigmoid")) - return SIGMOID; - else if constexpr(detail::case_insensitive_equal(name, "TanH")) - return TANH; - else if constexpr(detail::case_insensitive_equal(name, "Gelu")) - return GELU; - else if constexpr(detail::case_insensitive_equal(name, "Silu")) - return SILU; - else - { - report_unsupported_elementwise_op_error(); - return PASS_THROUGH; // Unreachable - } -} - -/// @brief Derives a gemm padding from a kernel instance type. 
-/// @tparam Instance - A Device Kernel object type. -/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. -template -constexpr builder::GemmPadding gemm_spec() - requires HasGemmSpec> -{ - using InstTraits = InstanceTraits; - using enum builder::GemmPadding; - using enum ck::tensor_operation::device::GemmSpecialization; - - constexpr auto gemm_spec = InstTraits::kGemmSpecialization; - - switch(gemm_spec) - { - case Default: return DEFAULT; - case MPadding: return M_PADDING; - case NPadding: return N_PADDING; - case KPadding: return K_PADDING; - case MNPadding: return MN_PADDING; - case MKPadding: return MK_PADDING; - case NKPadding: return NK_PADDING; - case MNKPadding: return MNK_PADDING; - case OPadding: return O_PADDING; - case MOPadding: return MO_PADDING; - case NOPadding: return NO_PADDING; - case KOPadding: return KO_PADDING; - case MNOPadding: return MNO_PADDING; - case MKOPadding: return MKO_PADDING; - case NKOPadding: return NKO_PADDING; - case MNKOPadding: return MNKO_PADDING; - } -} - -/// @brief Primary template for extracting convolution traits. -/// @details This struct is the main entry point for reflecting on a convolution -/// kernel's properties. It is specialized to handle different kinds of input types. -template -struct ConvTraits; - -/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. -/// @details This is the primary specialization used to extract a comprehensive -/// set of traits directly from a fully-formed device kernel `Instance` type. -/// It uses `InstanceTraits` to access the kernel's template parameters. -template - requires IsXdlFwdConv> -struct ConvTraits -{ - using InstTraits = InstanceTraits; - // --- Signature Information --- - /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). - static constexpr int spatial_dim = InstTraits::kSpatialDim; - /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). 
- static constexpr builder::ConvDirection direction = conv_direction(); - /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). - static constexpr auto layout = conv_layout(); - /// @brief The primary data type used in the computation (e.g., FP16, FP32). - static constexpr builder::DataType data_type = conv_data_type(); + int spatial_dim; + builder::ConvDirection direction; + std::array layout; // [input, weight, output] + builder::DataType data_type; - static constexpr builder::ElementwiseOperation input_element_op = - elementwise_op(); - static constexpr builder::ElementwiseOperation weight_element_op = - elementwise_op(); - static constexpr builder::ElementwiseOperation output_element_op = - elementwise_op(); + builder::ElementwiseOperation input_element_op; + builder::ElementwiseOperation weight_element_op; + builder::ElementwiseOperation output_element_op; - /// @brief The GEMM specialization used by the kernel - padding - static constexpr auto gemm_padding = gemm_spec(); - /// @brief The convolution-specific specialization (e.g., Default, 1x1). - static constexpr auto conv_specialization = conv_spec(); + builder::GemmPadding gemm_padding; + builder::ConvSpecialization conv_specialization; // --- Algorithm Information --- - /// @brief The total number of threads in a thread block (workgroup). - static constexpr int thread_block_size = InstTraits::kBlockSize; - /// @brief The dimensions of the data tile processed by the thread block. - static constexpr DataTileInfo tile_dims = { - .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + int thread_block_size; + DataTileInfo tile_dims; - /// @brief Configuration for the A-matrix (input) tile transfer. 
- static constexpr InputTileTransferInfo a_tile_transfer = { - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, - .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kAK1}, - .transfer_params = {.k1 = InstTraits::kAK1, - .thread_cluster_dims = InstTraits::kAThreadClusterLengths, - .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, - .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, - .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kABlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + InputTileTransferInfo a_tile_transfer; + InputTileTransferInfo b_tile_transfer; - /// @brief Configuration for the B-matrix (weights) tile transfer. - static constexpr InputTileTransferInfo b_tile_transfer = { - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, - .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kBK1}, - .transfer_params = {.k1 = InstTraits::kBK1, - .thread_cluster_dims = InstTraits::kBThreadClusterLengths, - .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, - .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, - .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kBBlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + WarpGemmParams warp_gemm; - /// @brief Parameters for the warp-level GEMM computation. - static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, - .gemm_n = InstTraits::kNPerXDL, - .m_iter = InstTraits::kMXdlPerWave, - .n_iter = InstTraits::kNXdlPerWave}; + OutputTileTransferInfo c_tile_transfer; - /// @brief Configuration for the C-matrix (output) tile transfer. 
- static constexpr OutputTileTransferInfo c_tile_transfer = { - .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, - .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, - .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], - InstTraits::kCThreadClusterLengths[1], - InstTraits::kCThreadClusterLengths[2], - InstTraits::kCThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; - - /// @brief Helper to safely get the pipeline version. - /// @details This is only available for some convolutions (e.g., forward). - /// If not present in `InstanceTraits`, it returns a default value. - template - static constexpr auto get_pipeline_version() - { - if constexpr(requires { T::kPipelineVersion; }) - { - return convert_pipeline_version(); - } - else - { - // Return a default or indicate not available - return builder::PipelineVersion::V1; - } - } - - /// @brief The block GEMM pipeline version used by the kernel. - static constexpr auto pipeline_version = get_pipeline_version(); - - /// @brief Helper to safely get the pipeline scheduler. - /// @details This is only available for some convolutions. If not present - /// in `InstanceTraits`, it returns a default value. - template - static constexpr auto get_pipeline_scheduler() - { - if constexpr(requires { T::kPipelineScheduler; }) - { - return convert_pipeline_scheduler(); - } - else if constexpr(requires { T::kLoopScheduler; }) - { - return convert_pipeline_scheduler(); - } - else - { - // Return a default or indicate not available - return builder::PipelineScheduler::DEFAULT; - } - } - - /// @brief The pipeline scheduler used by the kernel. 
- static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); + builder::PipelineVersion pipeline_version; + builder::PipelineScheduler pipeline_scheduler; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000..cdd238f36a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / 
InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace 
ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..28c43c342f --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = 
InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp new file mode 100644 index 0000000000..c4bed850eb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = 
InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp new file mode 100644 index 0000000000..46c196e95a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -0,0 +1,739 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/reflect/conv_types.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +/// @file conv_traits_helpers.hpp +/// @brief Helper utilities for extracting convolution traits from kernel instances +/// +/// This file provides compile-time reflection utilities to extract configuration +/// information from CK convolution kernel instances and convert them to the builder +/// framework's standardized representation. +/// +/// ## Organization +/// +/// The file is organized into the following sections: +/// +/// 1. **Enum Conversions**: Functions to convert CK enums to builder enums +/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion) +/// - Pipeline scheduler conversions (BlockGemmPipelineScheduler, LoopScheduler) +/// +/// 2. 
**Signature Derivation**: Functions to extract signature information from instances +/// - Convolution direction (conv_direction) +/// - Convolution specialization (conv_spec) +/// - Tensor layouts (conv_layout) +/// - Data types (conv_data_type) +/// - Elementwise operations (elementwise_op) +/// - GEMM padding (gemm_spec) +/// +/// 3. **Pipeline Configuration Helpers**: Safe extraction of pipeline parameters +/// - Pipeline version extraction (get_pipeline_version) +/// - Pipeline scheduler extraction (get_pipeline_scheduler) +/// +/// ## Error Handling Strategy +/// +/// This file uses a specific error handling pattern for compile-time errors: +/// - **consteval functions with throw**: Used for error reporting to ensure SFINAE doesn't +/// silently ignore errors. The thrown string becomes part of the compiler error message, +/// providing clear context to developers. +/// - **DO NOT replace with static_assert**: static_assert is silently ignored during SFINAE, +/// which would hide errors instead of reporting them clearly. +/// +/// @example +/// ```cpp +/// using Instance = +/// ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<...>; +/// +/// // Extract convolution direction +/// constexpr auto dir = conv_direction(); +/// +/// // Extract data type +/// constexpr auto dtype = conv_data_type(); +/// +/// // Extract layout configuration +/// constexpr auto layouts = conv_layout(); +/// ``` + +namespace ck_tile::reflect::conv { + +// ============================================================================ +// SECTION 1: ENUM CONVERSIONS +// ============================================================================ + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value. 
+/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. +/// +/// Supported mappings: +/// - v1 -> V1 +/// - v2 -> V2 +/// - v3 -> V3 +/// - v4 -> V4 +/// - v5 -> V5 +template +constexpr builder::PipelineVersion convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v3: return V3; + case v4: return V4; + case v5: return V5; + } +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value. +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. +/// +/// Supported mappings: +/// - v1 -> V1 +/// - v2 -> V2 +/// - v4 -> V4 +/// - weight_only -> WEIGHT_ONLY +template +constexpr builder::PipelineVersion convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v4: return V4; + case weight_only: return WEIGHT_ONLY; + } +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value. 
+/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// +/// Supported mappings: +/// - Intrawave -> INTRAWAVE: Scheduling within a single wavefront +/// - Interwave -> INTERWAVE: Coordination across multiple wavefronts +template +constexpr builder::PipelineScheduler convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + + switch(ck_sched) + { + case Intrawave: return INTRAWAVE; + case Interwave: return INTERWAVE; + } +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value. +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. 
+/// +/// Supported mappings: +/// - Default -> DEFAULT: Standard scheduling strategy +/// - Interwave -> INTERWAVE: Cross-wavefront coordination for improved performance +template +constexpr builder::PipelineScheduler convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + + switch(ck_sched) + { + case Default: return DEFAULT; + case Interwave: return INTERWAVE; + } +} + +// ============================================================================ +// SECTION 2: SIGNATURE DERIVATION FUNCTIONS +// ============================================================================ + +// ---------------------------------------------------------------------------- +// Convolution Direction +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported convolution direction with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_conv_direction_error() +{ + throw "Unsupported convolution direction detected!\n" + "The kernel instance does not have a recognized convolution specialization.\n" + "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " + "kConvBwdWeightSpecialization.\n" + "Please verify that your kernel instance is properly configured."; +} + +/// @brief Derives the convolution direction from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. +/// @return A builder::ConvDirection enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). +/// @details This function inspects the Instance's InstanceTraits to determine which +/// convolution specialization field is present, and returns the corresponding direction. 
+/// +/// The function checks for the presence of: +/// - kConvForwardSpecialization -> FORWARD +/// - kConvBwdDataSpecialization -> BACKWARD_DATA +/// - kConvBwdWeightSpecialization -> BACKWARD_WEIGHT +/// +/// @note Compilation will fail with a clear error message if the instance does not +/// have a recognized convolution specialization field. +template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + return builder::ConvDirection::FORWARD; + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + return builder::ConvDirection::BACKWARD_DATA; + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + return builder::ConvDirection::BACKWARD_WEIGHT; + else + { + report_unsupported_conv_direction_error(); + return builder::ConvDirection::FORWARD; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Convolution Specialization +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported convolution specialization with a clear error +/// message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_conv_spec_error() +{ + throw "Unsupported convolution specialization detected!\n" + "The kernel instance does not have a recognized convolution specialization field.\n" + "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " + "kConvBwdWeightSpecialization.\n" + "Please verify that your kernel instance is properly configured."; +} + +/// @brief Derives the convolution-specific specialization from a device kernel Instance type. 
+/// @tparam Instance The device kernel instance type. +/// @return A builder::ConvSpecialization enum value. +/// @details This function extracts the specialization enum from the Instance's InstanceTraits +/// and converts it to the corresponding builder framework enum. +/// +/// For forward convolutions, supported specializations include: +/// - Default, Filter1x1Pad0, Filter1x1Stride1Pad0, Filter3x3, OddC +/// +/// For backward data convolutions: +/// - Default, Filter1x1Stride1Pad0 +/// +/// For backward weight convolutions: +/// - Default, Filter1x1Stride1Pad0, Filter1x1Pad0, OddC +template +constexpr builder::ConvSpecialization conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvForwardSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter3x3: return FILTER_3x3; + case OddC: return ODD_C; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvBwdDataSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvBwdWeightSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case OddC: return ODD_C; + } + } + else + { + report_unsupported_conv_spec_error(); 
+ return builder::ConvSpecialization::DEFAULT; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Tensor Layouts +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported layout combinations with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_layout_error() +{ + throw "Unsupported convolution layout combination detected!\n" + "The combination of ALayout, BLayout, and ELayout template parameters\n" + "is not recognized for the given spatial dimension.\n" + "Please verify that your convolution instance uses a supported layout configuration.\n" + "Check the conv_layout() function for the list of supported layout combinations."; +} + +/// @brief Derives the grouped convolution layout from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array containing the layouts for: +/// - [0] Input tensor layout +/// - [1] Weight tensor layout +/// - [2] Output tensor layout +/// @details This function examines the Instance's ALayout, BLayout, and ELayout types +/// along with the spatial dimension to determine the appropriate layout configuration. +/// +/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions). +/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants. +/// +/// @note Compilation will fail with a clear error message if the layout combination +/// is not supported for the given spatial dimension. +/// +/// TODO: If we don't check for supported layouts, this function can be simplified. 
+template +constexpr std::array conv_layout() +{ + using InstTraits = InstanceTraits; + using A = typename InstTraits::ALayout; + using B = typename InstTraits::BLayout; + using E = typename InstTraits::ELayout; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; + + // Helper to check if layouts match expected types + constexpr auto layouts_match = []() { + return std::is_same_v && std::is_same_v && std::is_same_v; + }; + + // Helper to construct layout array + constexpr auto make_layouts = [](auto in, auto weight, auto out) { + return std::array{in, weight, out}; + }; + + constexpr int spatial_dim = InstTraits::kSpatialDim; + + if constexpr(spatial_dim == 1) + { + if constexpr(layouts_match.template operator()()) + return make_layouts(GNWC, GKXC, GNWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(GNWC, GKXC, GNWK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NWGC, GKXC, NWGK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCW, GKXC, NGKW); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCW, GKCX, NGKW); + else + { + report_unsupported_layout_error(); + return make_layouts(GNWC, GKXC, GNWK); // Unreachable + } + } + else if constexpr(spatial_dim == 2) + { + if constexpr(layouts_match.template operator()()) + return make_layouts(GNHWC, GKYXC, GNHWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(GNHWC, GKYXC, GNHWK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NHWGC, GKYXC, NHWGK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NHWGC, GKYXC, NHWGK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCHW, GKYXC, NGKHW); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCHW, GKCYX, NGKHW); + else + { + 
report_unsupported_layout_error(); + return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable + } + } + else if constexpr(spatial_dim == 3) + { + if constexpr(layouts_match.template operator()()) + return make_layouts(GNDHWC, GKZYXC, GNDHWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(GNDHWC, GKZYXC, GNDHWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(NDHWGC, GKZYXC, NDHWGK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(NGCDHW, GKZYXC, NGKDHW); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(NGCDHW, GKCZYX, NGKDHW); + else + { + report_unsupported_layout_error(); + return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable + } + } + else + { + report_unsupported_layout_error(); + return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Data Types +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported data type with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_data_type_error() +{ + throw "Unsupported data type detected!\n" + "The ADataType is not recognized.\n" + "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " + "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " + "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " + "(BF8), " + "int8_t (I8), ck::Tuple (I8_I8), uint8_t (U8).\n" + "Please verify that your kernel instance uses a supported data type."; +} + +/// @brief Derives the data type from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. 
+/// @return A builder::DataType enum value representing the input data type. +/// @details This function examines the Instance's ADataType to determine the data type +/// used for the input tensor. The function supports various floating-point and integer +/// types, including tuple types for mixed-precision operations. +/// +/// Supported data types include: +/// - FP16 (ck::half_t) +/// - FP16_FP16 (ck::Tuple) +/// - BF16 (ck::bhalf_t) +/// - BF16_BF16 (ck::Tuple) +/// - FP32 (float) +/// - FP32_FP32 (ck::Tuple) +/// - FP64 (double) +/// - FP8 (ck::f8_t) +/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t) +/// - I8 (int8_t) +/// - I8_I8 (ck::Tuple) +/// - U8 (uint8_t) +template +constexpr builder::DataType conv_data_type() +{ + using InstTraits = InstanceTraits; + using ADataType = typename InstTraits::ADataType; + using enum builder::DataType; + + if constexpr(std::is_same_v) + return FP16; + else if constexpr(std::is_same_v>) + return FP16_FP16; + else if constexpr(std::is_same_v) + return BF16; + else if constexpr(std::is_same_v>) + return BF16_BF16; + else if constexpr(std::is_same_v) + return FP32; + else if constexpr(std::is_same_v>) + return FP32_FP32; + else if constexpr(std::is_same_v) + return FP64; + else if constexpr(std::is_same_v) + return FP8; + else if constexpr(std::is_same_v) + return BF8; + else if constexpr(std::is_same_v) + return BF8; + else if constexpr(std::is_same_v) + return I8; + else if constexpr(std::is_same_v>) + return I8_I8; + else if constexpr(std::is_same_v) + return U8; + else + { + report_unsupported_data_type_error(); + return FP32; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Elementwise Operations +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported elementwise operation with a clear error message. 
+/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_elementwise_op_error() +{ + throw "Unsupported elementwise operation detected!\n" + "The elementwise operation type is not recognized.\n" + "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, " + "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, " + "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, " + "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, " + "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, " + "UnaryConvert.\n" + "Please verify that your kernel instance uses a supported elementwise operation."; +} + +/// @brief Derives the elementwise operation from an operation functor type. +/// @tparam ElementwiseOp Elementwise operation functor type. +/// @return A builder::ElementwiseOperation enum value corresponding to the operation. +/// @details This function uses the operation's type name to determine which elementwise +/// operation is being used. The comparison is case-insensitive. +/// +/// Supported operations include: +/// - Activation operations: Relu, Sigmoid, Tanh, Gelu, Silu, Elu, Swish, etc. +/// - Scaling operations: Scale, ScaleAdd, ConvScale, ConvScaleAdd, etc. +/// - Clamping operations: Clamp, AddClamp, etc. +/// - Combined operations: Add_Activation_Mul_Clamp, etc. +/// - Utility operations: PassThrough, UnaryConvert, etc. +/// +/// TODO: Consider changing this to direct checks on the types, not strings. 
+template +constexpr builder::ElementwiseOperation elementwise_op() +{ + using enum builder::ElementwiseOperation; + constexpr std::string_view name = detail::elementwise_op_name(); + + if constexpr(detail::case_insensitive_equal(name, "AddClamp")) + return ADD_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd")) + return ADD_RELU_ADD; + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + return BIAS_BNORM_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + return BILINEAR; + else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp")) + return BIAS_BNORM_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) + return CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale")) + return CONV_INVSCALE; + else if constexpr(detail::case_insensitive_equal(name, "ConvScale")) + return CONV_SCALE; + else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd")) + return CONV_SCALE_ADD; + else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu")) + return CONV_SCALE_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Scale")) + return SCALE; + else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd")) + return SCALE_ADD; + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + return PASS_THROUGH; + else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + return SCALEADD_SCALEADD_RELU; + else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp")) + return DYNAMIC_UNARY_OP; + else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp")) + return UNARY_COMBINED_OP; + else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp")) + return ACTIVATION_MUL2_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp")) + return ACTIVATION_MUL_CLAMP; + else if 
constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp")) + return ADD_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp")) + return ADD_ACTIVATION_MUL2_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp")) + return ADD_MUL_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp")) + return ADD_MUL2_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert")) + return UNARY_CONVERT; + else if constexpr(detail::case_insensitive_equal(name, "Logistic")) + return LOGISTIC; + else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu")) + return CLIPPED_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Swish")) + return SWISH; + else if constexpr(detail::case_insensitive_equal(name, "Elu")) + return ELU; + else if constexpr(detail::case_insensitive_equal(name, "Power")) + return POWER; + else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu")) + return LEAKY_RELU; + else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs")) + return UNARY_ABS; + else if constexpr(detail::case_insensitive_equal(name, "Relu")) + return RELU; + else if constexpr(detail::case_insensitive_equal(name, "SoftRelu")) + return SOFT_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Sigmoid")) + return SIGMOID; + else if constexpr(detail::case_insensitive_equal(name, "TanH")) + return TANH; + else if constexpr(detail::case_insensitive_equal(name, "Gelu")) + return GELU; + else if constexpr(detail::case_insensitive_equal(name, "Silu")) + return SILU; + else + { + report_unsupported_elementwise_op_error(); + return PASS_THROUGH; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// GEMM Padding +// ---------------------------------------------------------------------------- + +/// 
@brief Derives the GEMM padding specification from a kernel instance type. +/// @tparam Instance A device kernel instance type. +/// @return A builder::GemmPadding enum value corresponding to the kernel's padding configuration. +/// @details This function extracts the GEMM specialization from the Instance's InstanceTraits +/// and converts it to the builder framework's GemmPadding enum. The padding specification +/// indicates which dimensions (M, N, K, O) are padded to handle non-aligned tensor sizes. +/// +/// Supported padding configurations include: +/// - DEFAULT: No padding +/// - M_PADDING, N_PADDING, K_PADDING, O_PADDING: Single dimension padding +/// - MN_PADDING, MK_PADDING, NK_PADDING, etc.: Two dimension padding +/// - MNK_PADDING, MNO_PADDING, etc.: Three dimension padding +/// - MNKO_PADDING: All dimensions padded +template +constexpr builder::GemmPadding gemm_spec() +{ + using InstTraits = InstanceTraits; + using enum builder::GemmPadding; + using enum ck::tensor_operation::device::GemmSpecialization; + + constexpr auto spec = InstTraits::kGemmSpecialization; + + switch(spec) + { + case Default: return DEFAULT; + case MPadding: return M_PADDING; + case NPadding: return N_PADDING; + case KPadding: return K_PADDING; + case MNPadding: return MN_PADDING; + case MKPadding: return MK_PADDING; + case NKPadding: return NK_PADDING; + case MNKPadding: return MNK_PADDING; + case OPadding: return O_PADDING; + case MOPadding: return MO_PADDING; + case NOPadding: return NO_PADDING; + case KOPadding: return KO_PADDING; + case MNOPadding: return MNO_PADDING; + case MKOPadding: return MKO_PADDING; + case NKOPadding: return NKO_PADDING; + case MNKOPadding: return MNKO_PADDING; + } +} + +// ============================================================================ +// SECTION 3: PIPELINE CONFIGURATION HELPERS +// ============================================================================ + +/// @brief Safely extracts the pipeline version from InstanceTraits. 
+/// @tparam InstTraits The InstanceTraits type to extract pipeline version from. +/// @return The pipeline version as a builder::PipelineVersion enum value. +/// @details This helper function checks if the InstanceTraits has a kPipelineVersion +/// field and extracts it if present. If not present, it returns a default value (V1). +/// This is necessary because not all convolution types expose pipeline version information. +template +constexpr builder::PipelineVersion get_pipeline_version() +{ + if constexpr(requires { InstTraits::kPipelineVersion; }) + { + return convert_pipeline_version(); + } + else + { + return builder::PipelineVersion::V1; + } +} + +/// @brief Safely extracts the pipeline scheduler from InstanceTraits. +/// @tparam InstTraits The InstanceTraits type to extract pipeline scheduler from. +/// @return The pipeline scheduler as a builder::PipelineScheduler enum value. +/// @details This helper function checks if the InstanceTraits has a kPipelineScheduler +/// or kLoopScheduler field and extracts it if present. If neither is present, it returns +/// a default value (DEFAULT). This is necessary because different convolution types may +/// expose scheduler information through different field names. 
+template +constexpr builder::PipelineScheduler get_pipeline_scheduler() +{ + if constexpr(requires { InstTraits::kPipelineScheduler; }) + { + return convert_pipeline_scheduler(); + } + else if constexpr(requires { InstTraits::kLoopScheduler; }) + { + return convert_pipeline_scheduler(); + } + else + { + return builder::PipelineScheduler::DEFAULT; + } +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp new file mode 100644 index 0000000000..00010e2d48 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -0,0 +1,8 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f5f3df3159..71db59afb6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -74,6 +74,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag +{ +}; + // 
Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index ace1b09224..4549b76a3f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -78,6 +78,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 device kernel +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag +{ +}; + // Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 09274d5acd..046e5c3078 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ 
b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -73,6 +73,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor device kernel +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag +{ +}; + // Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index ddcf8db476..9890563859 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -108,7 +108,8 @@ target_link_libraries(test_ckb_reference_execution PRIVATE utility) # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits conv/ck/test_conv_traits.cpp - conv/ck/unit_instance_to_conv_traits.cpp) + conv/ck/unit_instance_to_conv_traits_features.cpp + conv/ck/unit_instance_to_conv_traits_instances.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index b3a76e4e11..42235df2fe 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include #include @@ -86,72 +86,72 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) ck::half_t, // BComputeDataType false>; // DirectLoad - // Use ConvTraits to extract compile-time information - using Traits = 
ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); // Verify A tile transfer info - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); - 
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); - EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); // Verify B tile transfer info - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, 
ElementsAre(1, 0, 2)); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); - EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); // Verify warp GEMM params - EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); - EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); - EXPECT_EQ(Traits::warp_gemm.m_iter, 4); - EXPECT_EQ(Traits::warp_gemm.n_iter, 4); + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); // Verify output tile transfer info - EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); - EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); - EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); - EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + 
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); // Verify pipeline configuration - EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); - EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1); + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); } // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle @@ -214,30 +214,30 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) ck::LoopScheduler::Default, // LoopSched 1>; // NumGroupsToMerge - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, 
ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); } // Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) @@ -298,29 +298,29 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) ck::half_t, // BComputeDataType ck::LoopScheduler::Default>; // LoopSched - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, 
ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); } } // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp deleted file mode 100644 index 9d6fab19d1..0000000000 --- a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp +++ /dev/null @@ -1,1127 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -// ============================================================================ -// Unit Tests for InstanceTraits to ConvTraits Conversion -// ============================================================================ -// -// PURPOSE: -// -------- -// These tests verify the conversion layer between InstanceTraits (low-level -// template parameter extraction) and ConvTraits (high-level semantic traits). -// The conversion transforms raw CK kernel parameters into builder-friendly -// enums and structures. -// -// DESIGN RATIONALE: -// ----------------- -// ConvTraits uses a single generic specialization that works with any Device -// class satisfying the IsXdlFwdConv concept. 
This use of concepts is fragile -// and introduces extra complexity. We want to refector to just use functions -// for this conversion. -// -// These tests are intentionally verbose and repetitive to provide maximum -// coverage during refactoring. Once the refactoring is complete and stable, -// they can be simplified or consolidated. -// -// TEST COVERAGE: -// -------------- -// 1. Enum conversion functions (pipeline version, scheduler, etc.) -// 2. Signature extraction (direction, specialization, layout, data type) -// 3. Full transformation verification for each XDL Device class template: -// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle -// - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor -// -// NOTE: WMMA and DL (Direct Load) variants are not covered as they don't -// satisfy the IsXdlFwdConv concept (different tile parameter structure). -// ============================================================================ - -#include "ck/utility/scheduler_enum.hpp" -#include "ck_tile/builder/types.hpp" -#include -#include - -#include -#include -#include -#include -#include - -namespace { - -using ck_tile::builder::ConvDirection; -using ck_tile::builder::DataType; -using ck_tile::builder::ElementwiseOperation; -using ck_tile::builder::GemmPadding; -using ck_tile::builder::PipelineScheduler; -using ck_tile::builder::PipelineVersion; -using ck_tile::builder::TensorLayout; -using ::testing::ElementsAre; - -// ============================================================================ -// Test Enum Conversion Functions -// ============================================================================ - -TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion) -{ - using ck_tile::reflect::conv::convert_pipeline_version; - using enum ::ck::BlockGemmPipelineVersion; - using enum ::ck_tile::builder::PipelineVersion; - - EXPECT_EQ(convert_pipeline_version(), V1); - EXPECT_EQ(convert_pipeline_version(), V2); 
- EXPECT_EQ(convert_pipeline_version(), V3); - EXPECT_EQ(convert_pipeline_version(), V4); - EXPECT_EQ(convert_pipeline_version(), V5); -} - -TEST(InstanceToConvTraits, ConvertsPipelineVersion) -{ - using ck_tile::reflect::conv::convert_pipeline_version; - using enum ck::PipelineVersion; - using enum PipelineVersion; - - EXPECT_EQ(convert_pipeline_version(), V1); - EXPECT_EQ(convert_pipeline_version(), V2); - EXPECT_EQ(convert_pipeline_version(), V4); - EXPECT_EQ(convert_pipeline_version(), WEIGHT_ONLY); -} - -TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler) -{ - using ck_tile::reflect::conv::convert_pipeline_scheduler; - using enum ck::BlockGemmPipelineScheduler; - using enum PipelineScheduler; - - EXPECT_EQ(convert_pipeline_scheduler(), INTRAWAVE); - EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); -} - -TEST(InstanceToConvTraits, ConvertsLoopScheduler) -{ - using ck_tile::reflect::conv::convert_pipeline_scheduler; - using enum ck::LoopScheduler; - using enum PipelineScheduler; - - EXPECT_EQ(convert_pipeline_scheduler(), DEFAULT); - EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); -} - -// ============================================================================ -// Test Convolution Direction Detection -// ============================================================================ - -TEST(InstanceToConvTraits, DetectsForwardDirection) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // 
BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); -} - -// ============================================================================ -// Test Convolution Specialization Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsDefaultSpecialization) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - 
ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); -} - -TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); -} - -// ============================================================================ -// Test Layout Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsGnhwcLayout) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - 
ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); -} - -TEST(InstanceToConvTraits, ExtractsNhwgcLayout) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::NHWGC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::NHWGK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - 
ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK)); -} - -TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::NGCHW, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::NGKHW, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW)); -} - -TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::NGCHW, - ck::tensor_layout::convolution::GKCYX, - ck::Tuple<>, - ck::tensor_layout::convolution::NGKHW, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - 
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW)); -} - -// ============================================================================ -// Test Data Type Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsFp16DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - 
- using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::FP16); -} - -TEST(InstanceToConvTraits, ExtractsBf16DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::bhalf_t, - ck::bhalf_t, - float, - ck::bhalf_t, - ck::Tuple<>, - ck::bhalf_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::bhalf_t, - ck::bhalf_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::BF16); -} - -TEST(InstanceToConvTraits, ExtractsFp32DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - float, - float, - float, - float, - ck::Tuple<>, - float, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 
32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - float, - float, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::FP32); -} - -TEST(InstanceToConvTraits, ExtractsI8DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - int8_t, - int8_t, - int32_t, - int8_t, - ck::Tuple<>, - int8_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - int8_t, - int8_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::I8); -} - -// ============================================================================ -// Test GEMM Padding Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding) -{ - using DeviceInstance = - 
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::gemm_padding, GemmPadding::DEFAULT); -} - -TEST(InstanceToConvTraits, ExtractsMnkGemmPadding) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::MNKPadding, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 
0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::gemm_padding, GemmPadding::MNK_PADDING); -} - -// ============================================================================ -// Comprehensive Transformation Tests - Per Device Class Template -// ============================================================================ -// These tests verify the complete InstanceTraits → ConvTraits transformation -// for each forward convolution Device class template. They are verbose to -// provide maximum safety during refactoring. -// ============================================================================ - -TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // KPerBlock - 8, // AK1 - 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 4, // NXdlPerWave - ck::Sequence<4, 64, 1>, // 
ABlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths - 8, // CDEBlockTransferScalarPerVector_NPerBlock - ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched - ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer - ck::half_t, // AComputeDataType - ck::half_t, // BComputeDataType - false>; // DirectLoad - - using InstTraits = ck_tile::reflect::InstanceTraits; - using ConvTraits = ck_tile::reflect::conv::ConvTraits; - - // Verify signature information - EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); - EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); - EXPECT_EQ(ConvTraits::data_type, DataType::FP16); - EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); - - // Verify tile dimensions - EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); - - // Verify pipeline configuration - EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); - EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); -} - -TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< - 
2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 1, // NumGemmKPrefetchStage - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // KPerBlock - 8, // AK1 - 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 4, // NXdlPerWave - ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths - 8, // CDEBlockTransferScalarPerVector_NPerBlock - ck::half_t, // AComputeDataType - ck::half_t, // BComputeDataType - ck::LoopScheduler::Default, // LoopSched - 1>; // NumGroupsToMerge - - using InstTraits = ck_tile::reflect::InstanceTraits; - 
using ConvTraits = ck_tile::reflect::conv::ConvTraits; - - // Verify signature information - EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); - EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); - EXPECT_EQ(ConvTraits::data_type, DataType::FP16); - EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); - - // Verify tile dimensions - EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); - - // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler) - EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::DEFAULT); - EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); -} - -TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - 2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 1, // NumGemmKPrefetchStage - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // KPerBlock - 8, // AK1 - 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 4, // NXdlPerWave - ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // 
ABlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths - 8, // CDEBlockTransferScalarPerVector_NPerBlock - ck::half_t, // AComputeDataType - ck::half_t, // BComputeDataType - ck::LoopScheduler::Default>; // LoopSched - - using InstTraits = ck_tile::reflect::InstanceTraits; - using ConvTraits = ck_tile::reflect::conv::ConvTraits; - - // Verify signature information - EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); - EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); - EXPECT_EQ(ConvTraits::data_type, DataType::FP16); - EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); - - // Verify tile dimensions - EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); - - // Verify pipeline configuration - EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::DEFAULT); - EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); -} - -} // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp new file mode 100644 index 0000000000..72269c38ac --- /dev/null +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp @@ 
-0,0 +1,800 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// ============================================================================ +// Unit Tests for Individual Conversion Functions +// ============================================================================ +// +// PURPOSE: +// -------- +// These tests verify individual conversion and extraction functions that +// transform raw CK kernel parameters into semantic types. Each test +// focuses on a single conversion function to ensure it correctly maps +// CK types to builder enums and structures. +// +// TEST COVERAGE: +// -------------- +// 1. Enum Conversions: +// - Pipeline versions (BlockGemmPipelineVersion and PipelineVersion) +// - Pipeline schedulers (BlockGemmPipelineScheduler and LoopScheduler) +// +// 2. Elementwise Operations (14 operations): +// - PassThrough, Scale, Relu, Gelu, Sigmoid, Tanh, ScaleAdd +// - Silu, Swish, Elu, LeakyRelu, UnaryConvert, ConvScale, ConvScaleAdd +// +// 3. Convolution Properties: +// - Direction detection (Forward) +// - Specializations (Default, Filter1x1Pad0, Filter1x1Stride1Pad0, +// Filter3x3, OddC) +// +// 4. Layout Detection: +// - 1D layouts (GNWC, NWGC, NGCW) +// - 2D layouts (GNHWC, NHWGC, NGCHW with GKYXC/GKCYX) +// - 3D layouts (GNDHWC, NDHWGC, NGCDHW) +// +// 5. Data Type Detection: +// - FP16, BF16, FP32, I8 +// +// 6. Pipeline Configuration: +// - Pipeline versions (V2, V3) +// - Schedulers (Interwave) +// +// 7. 
GEMM Padding Variations (17 types): +// - Default, MNK, M, N, K, MN, MK, NK +// - O, MO, NO, KO, MNO, MKO, NKO, MNKO +// ============================================================================ + +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/types.hpp" +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using ::ck_tile::builder::ConvDirection; +using ::ck_tile::builder::DataType; +using ::ck_tile::builder::ElementwiseOperation; +using ::ck_tile::builder::GemmPadding; +using ::ck_tile::builder::PipelineScheduler; +using ::ck_tile::builder::PipelineVersion; +using ::ck_tile::builder::TensorLayout; +using ::testing::ElementsAre; + +// ============================================================================ +// Test Helper Templates +// ============================================================================ +// These templates reduce boilerplate by providing sensible defaults for +// template parameters that don't vary in most tests. 
+// ============================================================================ + +namespace defaults { +// Default values used across most tests +static constexpr int kBlockSize = 256; +static constexpr int kMPerBlock = 128; +static constexpr int kNPerBlock = 128; +static constexpr int kKPerBlock = 16; +static constexpr int kAK1 = 8; +static constexpr int kBK1 = 8; +static constexpr int kMPerXDL = 32; +static constexpr int kNPerXDL = 32; +static constexpr int kMXdlPerWave = 4; +static constexpr int kNXdlPerWave = 4; +static constexpr int kABlockTransferSrcVectorDim = 2; +static constexpr int kABlockTransferSrcScalarPerVector = 8; +static constexpr int kABlockTransferDstScalarPerVector_AK1 = 8; +static constexpr int kABlockLdsExtraM = 1; +static constexpr int kBBlockTransferSrcVectorDim = 2; +static constexpr int kBBlockTransferSrcScalarPerVector = 8; +static constexpr int kBBlockTransferDstScalarPerVector_BK1 = 8; +static constexpr int kBBlockLdsExtraN = 1; +static constexpr int kCShuffleMXdlPerWavePerShuffle = 1; +static constexpr int kCShuffleNXdlPerWavePerShuffle = 1; +static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8; +static constexpr bool kDirectLoad = false; + +using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>; +using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>; +using DefaultABlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>; +using DefaultBBlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>; +using DefaultBBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>; +using DefaultBBlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>; +using DefaultCDEBlockTransferClusterLengths = ck::Sequence<1, 32, 1, 8>; +} // namespace defaults + +// DeviceInstanceForTests - V3 variant with sensible defaults +template +using DeviceInstanceForTests_V3 = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + NDimSpatial, + ALayout, + BLayout, + ck::Tuple<>, + 
ELayout, + ADataType, + BDataType, + AccDataType, + ADataType, + ck::Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ConvForwardSpecialization, + GemmSpec, + defaults::kBlockSize, + defaults::kMPerBlock, + defaults::kNPerBlock, + defaults::kKPerBlock, + defaults::kAK1, + defaults::kBK1, + defaults::kMPerXDL, + defaults::kNPerXDL, + defaults::kMXdlPerWave, + defaults::kNXdlPerWave, + defaults::DefaultABlockTransferThreadClusterLengths, + defaults::DefaultABlockTransferThreadClusterArrangeOrder, + defaults::DefaultABlockTransferSrcAccessOrder, + defaults::kABlockTransferSrcVectorDim, + defaults::kABlockTransferSrcScalarPerVector, + defaults::kABlockTransferDstScalarPerVector_AK1, + defaults::kABlockLdsExtraM, + defaults::DefaultBBlockTransferThreadClusterLengths, + defaults::DefaultBBlockTransferThreadClusterArrangeOrder, + defaults::DefaultBBlockTransferSrcAccessOrder, + defaults::kBBlockTransferSrcVectorDim, + defaults::kBBlockTransferSrcScalarPerVector, + defaults::kBBlockTransferDstScalarPerVector_BK1, + defaults::kBBlockLdsExtraN, + defaults::kCShuffleMXdlPerWavePerShuffle, + defaults::kCShuffleNXdlPerWavePerShuffle, + defaults::DefaultCDEBlockTransferClusterLengths, + defaults::kCDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ADataType, + BDataType, + defaults::kDirectLoad>; + +// Test case helper for specialization testing +template +using SpecializationTestInstance = + DeviceInstanceForTests_V3<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Spec>; + +// Test case helper for layout testing (1D, 2D, 3D) +template +using LayoutTestInstance = DeviceInstanceForTests_V3; + +// Test case helper 
for data type testing +template +using DataTypeTestInstance = DeviceInstanceForTests_V3<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + DataType, + DataType, + DataType, + AccDataType>; + +// Test case helper for pipeline version testing +template +using PipelineVersionTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + ck::BlockGemmPipelineScheduler::Intrawave, + PipelineVer>; + +// Test case helper for pipeline scheduler testing +template +using PipelineSchedulerTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + Scheduler>; + +// Test case helper for GEMM padding testing +template +using GemmPaddingTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + 
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + GemmSpec>; + +// ============================================================================ +// Test Enum Conversion Functions +// ============================================================================ + +TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion) +{ + using ck_tile::reflect::conv::convert_pipeline_version; + using enum ::ck::BlockGemmPipelineVersion; + using enum ::ck_tile::builder::PipelineVersion; + EXPECT_EQ(convert_pipeline_version(), V1); + EXPECT_EQ(convert_pipeline_version(), V2); + EXPECT_EQ(convert_pipeline_version(), V3); + EXPECT_EQ(convert_pipeline_version(), V4); + EXPECT_EQ(convert_pipeline_version(), V5); +} + +TEST(InstanceToConvTraits, ConvertsPipelineVersion) +{ + using ck_tile::reflect::conv::convert_pipeline_version; + using enum ck::PipelineVersion; + using enum PipelineVersion; + EXPECT_EQ(convert_pipeline_version(), V1); + EXPECT_EQ(convert_pipeline_version(), V2); + EXPECT_EQ(convert_pipeline_version(), V4); + EXPECT_EQ(convert_pipeline_version(), WEIGHT_ONLY); +} + +TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler) +{ + using ck_tile::reflect::conv::convert_pipeline_scheduler; + using enum ck::BlockGemmPipelineScheduler; + using enum PipelineScheduler; + EXPECT_EQ(convert_pipeline_scheduler(), INTRAWAVE); + EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); +} + +TEST(InstanceToConvTraits, ConvertsLoopScheduler) +{ + using ck_tile::reflect::conv::convert_pipeline_scheduler; + using enum ck::LoopScheduler; + using enum PipelineScheduler; + EXPECT_EQ(convert_pipeline_scheduler(), DEFAULT); + EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); +} + +// ============================================================================ +// Test Elementwise Operations +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsPassThroughOperation) +{ + using enum 
ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, PASS_THROUGH); +} + +TEST(InstanceToConvTraits, ExtractsScaleOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SCALE); +} + +TEST(InstanceToConvTraits, ExtractsReluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, RELU); +} + +TEST(InstanceToConvTraits, ExtractsGeluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, GELU); +} + +TEST(InstanceToConvTraits, ExtractsSigmoidOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SIGMOID); +} + +TEST(InstanceToConvTraits, ExtractsTanhOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, TANH); +} + +TEST(InstanceToConvTraits, ExtractsScaleAddOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SCALE_ADD); +} + +TEST(InstanceToConvTraits, ExtractsSiluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SILU); +} + +TEST(InstanceToConvTraits, ExtractsSwishOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SWISH); +} + +TEST(InstanceToConvTraits, ExtractsEluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, ELU); +} + +TEST(InstanceToConvTraits, ExtractsLeakyReluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, LEAKY_RELU); +} + 
+TEST(InstanceToConvTraits, ExtractsUnaryConvertOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, UNARY_CONVERT); +} + +TEST(InstanceToConvTraits, ExtractsConvScaleOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, CONV_SCALE); +} + +TEST(InstanceToConvTraits, ExtractsConvScaleAddOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, CONV_SCALE_ADD); +} + +// ============================================================================ +// Test Convolution Direction Detection +// ============================================================================ + +TEST(InstanceToConvTraits, DetectsForwardDirection) +{ + using DeviceInstance = DeviceInstanceForTests_V3<>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); +} + +// ============================================================================ +// Test Convolution Specialization Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsDefaultSpecialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); +} + +TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); +} + 
+TEST(InstanceToConvTraits, ExtractsFilter1x1Stride1Pad0Specialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, + ck_tile::builder::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0); +} + +TEST(InstanceToConvTraits, ExtractsFilter3x3Specialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_3x3); +} + +TEST(InstanceToConvTraits, ExtractsOddCSpecialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::ODD_C); +} + +// ============================================================================ +// Test 1D Convolution Layout Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsGnwcLayout) +{ + using DeviceInstance = LayoutTestInstance<1, + ck::tensor_layout::convolution::GNWC, + ck::tensor_layout::convolution::GKXC, + ck::tensor_layout::convolution::GNWK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 1); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNWC, TensorLayout::GKXC, TensorLayout::GNWK)); +} + +TEST(InstanceToConvTraits, ExtractsNwgcLayout) +{ + using DeviceInstance = LayoutTestInstance<1, + ck::tensor_layout::convolution::NWGC, + ck::tensor_layout::convolution::GKXC, + ck::tensor_layout::convolution::NWGK>; + const auto traits = 
ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 1); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NWGC, TensorLayout::GKXC, TensorLayout::NWGK)); +} + +TEST(InstanceToConvTraits, ExtractsNgcwLayout) +{ + using DeviceInstance = LayoutTestInstance<1, + ck::tensor_layout::convolution::NGCW, + ck::tensor_layout::convolution::GKXC, + ck::tensor_layout::convolution::NGKW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 1); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NGCW, TensorLayout::GKXC, TensorLayout::NGKW)); +} + +// ============================================================================ +// Test 2D Convolution Layout Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsGnhwcLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); +} + +TEST(InstanceToConvTraits, ExtractsNhwgcLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::NHWGC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::NHWGK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK)); +} + +TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::NGCHW, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::NGKHW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + 
ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW)); +} + +TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::NGCHW, + ck::tensor_layout::convolution::GKCYX, + ck::tensor_layout::convolution::NGKHW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW)); +} + +// ============================================================================ +// Test 3D Convolution Layout Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsGndhwcLayout) +{ + using DeviceInstance = LayoutTestInstance<3, + ck::tensor_layout::convolution::GNDHWC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::GNDHWK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK)); +} + +TEST(InstanceToConvTraits, ExtractsNdhwgcLayout) +{ + using DeviceInstance = LayoutTestInstance<3, + ck::tensor_layout::convolution::NDHWGC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::NDHWGK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NDHWGC, TensorLayout::GKZYXC, TensorLayout::NDHWGK)); +} + +TEST(InstanceToConvTraits, ExtractsNgcdhwLayout) +{ + using DeviceInstance = LayoutTestInstance<3, + ck::tensor_layout::convolution::NGCDHW, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::NGKDHW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_THAT(traits.layout, + 
ElementsAre(TensorLayout::NGCDHW, TensorLayout::GKZYXC, TensorLayout::NGKDHW)); +} + +// ============================================================================ +// Test Data Type Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsFp16DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::FP16); +} + +TEST(InstanceToConvTraits, ExtractsBf16DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::BF16); +} + +TEST(InstanceToConvTraits, ExtractsFp32DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::FP32); +} + +TEST(InstanceToConvTraits, ExtractsI8DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::I8); +} + +// ============================================================================ +// Test Pipeline Version Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsPipelineV2) +{ + using DeviceInstance = PipelineVersionTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V2); +} + +TEST(InstanceToConvTraits, ExtractsPipelineV3) +{ + using DeviceInstance = PipelineVersionTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V3); +} + +TEST(InstanceToConvTraits, ExtractsInterwaveScheduler) +{ + using DeviceInstance = PipelineSchedulerTestInstance; + const auto traits = 
ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTERWAVE); +} + +// ============================================================================ +// Test GEMM Padding Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); +} + +TEST(InstanceToConvTraits, ExtractsMnkGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MNK_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::M_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::N_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsKPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::K_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMnPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MN_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMkPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + 
EXPECT_EQ(traits.gemm_padding, GemmPadding::MK_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNkPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::NK_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsOPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::O_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::NO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsKoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::KO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMnoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MNO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MKO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = 
ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::NKO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMnkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MNKO_PADDING); +} + +} // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp new file mode 100644 index 0000000000..38942f9d45 --- /dev/null +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp @@ -0,0 +1,262 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// ============================================================================ +// Unit Tests for Complete Device Instance Transformations +// ============================================================================ +// +// PURPOSE: +// -------- +// These tests verify the complete instance_to_conv_traits transformation +// for entire Device class templates. Each test validates that all traits +// are correctly extracted from a specific Device class instantiation. +// +// TEST COVERAGE: +// -------------- +// Complete transformation verification for each XDL Device class template: +// 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +// 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +// 3. 
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// +// Each test verifies: +// - Spatial dimension extraction +// - Convolution direction +// - Data type detection +// - GEMM padding configuration +// - Tile dimensions (M, N, K per block) +// - Pipeline scheduler and version +// ============================================================================ + +#include + +#include +#include +#include +#include +#include + +namespace { + +using ::ck_tile::builder::ConvDirection; +using ::ck_tile::builder::DataType; +using ::ck_tile::builder::GemmPadding; +using ::ck_tile::builder::PipelineScheduler; +using ::ck_tile::builder::PipelineVersion; + +// ============================================================================ +// Comprehensive Transformation Tests - Per Device Class Template +// ============================================================================ +// These tests verify the complete InstanceTraits → ConvTraits transformation +// for each forward convolution Device class template. 
+// ============================================================================ + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // 
CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // 
NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler) + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor) +{ + 
using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 
ck::LoopScheduler::Default>; // LoopSched + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +} // anonymous namespace diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 7cb0ae20c3..cc343f6f69 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -2108,7 +2108,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle std::unique_ptr describe() const override { - static_assert(ck_tile::reflect::conv::HasConvTraits, + static_assert(ck_tile::reflect::HasConvTraits, "ConvTraits specialization not found for this device operation. 
" "If you modified the template parameters of this class, ensure that " "the corresponding ConvTraits specialization in " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 4f410d0cce..c9fb8ca3f6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -1282,7 +1282,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor std::unique_ptr describe() const override { static_assert( - ck_tile::reflect::conv::HasConvTraits, + ck_tile::reflect::HasConvTraits, "ConvTraits specialization not found for this device operation. " "If you modified the template parameters of this class, ensure that " "the corresponding ConvTraits specialization in " From 993d3e2f0e02c78d6cb20f040c688f7ccf338898 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 15 Jan 2026 22:11:44 +0800 Subject: [PATCH 09/99] [FMHA] Enable page size 16 for batch prefill kernel (#3568) * [FMHA] Enable page size 16 for batch prefill kernel * Refactor batch prefill KV offset logic to simplify template arguments - Remove redundant `kLog2PageSize` and `kIsVTileFitsInPage` from template args. - Add static assert to forbid `page_size=1` with vectorized layout. 
--- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 2 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 76 ++++++++++++++----- .../pipeline/block_fmha_pipeline_problem.hpp | 12 +-- 3 files changed, 62 insertions(+), 28 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 37d296aa91..9a2d727253 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,7 +36,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} -SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024] +SUPPORTED_PAGE_SIZE = [1, 16, 1024] SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] KV_MEMORY_LAYOUT_ENUM_MAP = { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 09b3f07883..c75f5d58c4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -17,12 +17,12 @@ template CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, const index_t& stride_token, @@ -31,6 +31,17 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, OffsetVecType& kv_offset_vec, index_t global_seq_offset = 0) { + static constexpr index_t kLog2PageSize = [] { + index_t shift = 0; + index_t val = kPageBlockSize; + while(val > 1) + { + val >>= 1; + shift++; + } + return shift; + }(); + const index_t& thread_coord_start = coord_vec[kCoordAxis]; constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; if constexpr(kIsKcache) @@ -48,7 +59,10 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, else { // for v offsets - if 
constexpr(kLog2PageSize == 0 && + // for page_size > 1, the V tile crosses pages when page_size is not a multiple of kN0. + static constexpr bool kVTileCrossesPages = + (kPageBlockSize > 1) && (kPageBlockSize % kN0 != 0); + if constexpr(kPageBlockSize == 1 && kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT) { // page size = 1, per-token page lookup. @@ -64,11 +78,42 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, kv_offset_vec[k0] = page_base_offset; }); } - else + else if constexpr(kVTileCrossesPages) { - // This path handles page_size > 1 and/or non-linear KV layout, where page_idx is - // indexed by page_id (token_idx >> log2_page_size) with an in-page offset. - // Assumes the V tile stays within a single page so lane0 can broadcast the page id. + // V tile crosses multiple pages (e.g., page_size < kN0), so page_id must be computed + // per token. + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + + const long_index_t page_base_offset = + static_cast(page_idx[page_id]) * stride_page_block; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize] + // address pattern. 
+ const long_index_t token_offset = + static_cast((token_idx_in_page / kVectorSize) * + (stride_token * kVectorSize)) + + (token_idx_in_page % kVectorSize); + + kv_offset_vec[k0] = page_base_offset + token_offset; + } + else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT + { + kv_offset_vec[k0] = page_base_offset + + static_cast(token_idx_in_page) * stride_token; + } + }); + } + else // !kVTileCrossesPages + { + // V tile is fully contained in one page, so page_id is shared. + // Use lane0 to compute page_id once and broadcast page_base_offset. const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); const index_t lane0_page_id = (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; @@ -77,8 +122,9 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, static_cast(page_idx[lane0_page_id]) * stride_page_block; static_for<0, kLoopCount, 1>{}([&](auto k0) { + // kLoopStride allows non-unit token spacing in the tile distribution. 
const index_t token_idx_in_page = - (global_seq_offset + thread_coord_start + kLoopStart + k0.value) & + (global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) & kInPageOffsetMask; if constexpr(kKVMemoryLayout == @@ -142,7 +188,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; - static constexpr index_t kLog2PageSize = Problem::kLog2PageSize; static constexpr index_t kVectorSize = Problem::kVectorSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -150,9 +195,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0, - "Page size must be 1, or a multiple of the tile size (kN0)."); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) // only need special care about seq_k padding (oob need set -INF of p instead of zero) @@ -456,12 +498,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(k_coord), 0, kPageBlockSize, - kLog2PageSize, 0, NRepeat, kN0 / NRepeat, kKVMemoryLayout, true, + kN0, kVectorSize>( page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); @@ -501,12 +543,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, 0, V_KRepeat, 1, kKVMemoryLayout, false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); @@ -587,12 +629,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, kK1, V_KRepeat, 1, kKVMemoryLayout, 
false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); @@ -761,12 +803,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, 2 * kK1, V_KRepeat, 1, kKVMemoryLayout, false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); @@ -900,12 +942,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, (2 + i_k1.value) * kK1, V_KRepeat, 1, kKVMemoryLayout, false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); @@ -957,12 +999,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(k_coord), 0, kPageBlockSize, - kLog2PageSize, 0, NRepeat, kN0 / NRepeat, kKVMemoryLayout, true, + kN0, kVectorSize>( page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index f9dc94bc65..a489eabb73 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -107,16 +107,6 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive"); static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, "kPageBlockSize must be power of two"); - static constexpr index_t kLog2PageSize = []() constexpr { - index_t shift = 0; - index_t val = kPageBlockSize_; - while(val > 1) - { - val >>= 1; - shift++; - } - return shift; - }(); static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; @@ -126,6 
+116,8 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0, "kQKHeaddim must be divisible by kVectorSize"); + static_assert(!(kPageBlockSize == 1 && kIsVectorizedLayout), + "page_size=1 only supports linear KV cache layout"); static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0, "kPageBlockSize must be divisible by kVectorSize for vectorized layout"); static_assert(kIsGroupMode_, "Batch prefill requires group mode"); From e30207985aa5d9d0b53dc837904bf2ac3063a412 Mon Sep 17 00:00:00 2001 From: Estevan Vedovelli Date: Thu, 15 Jan 2026 09:35:24 -0500 Subject: [PATCH 10/99] Fix error when building with -DCMAKE_BUILD_TYPE=Debug (#3541) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 897892afb2..ee8527c458 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -59,7 +59,7 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return TailHandler(run_func, has_hot_loop); else { - assert(("Wrong TailNumber!", false)); + assert(false && "Wrong TailNumber!"); return TailHandler(run_func, has_hot_loop); } } From 6df2d70143c0c8934f3dd08ec7086d5fdff16499 Mon Sep 17 00:00:00 2001 From: Yung-sheng Tu <112800063+yungshengtu@users.noreply.github.com> Date: Thu, 15 Jan 2026 16:19:31 +0100 Subject: [PATCH 11/99] Implement device_gemm_universal_preshuffle_instance for RDNA4 (#3429) * add device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp * add examples * add instances to test * remove duplicate code between examples --- example/01_gemm/CMakeLists.txt | 4 + 
.../01_gemm/gemm_wmma_fp16_bpreshuffle.cpp | 70 ++++ example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp | 72 +++++ .../run_gemm_wmma_bpreshuffle_example.inc | 206 ++++++++++++ ...ice_gemm_wmma_cshuffle_v3_b_preshuffle.hpp | 303 ++++++++++++++++++ .../gpu/gemm_universal_preshuffle.hpp | 43 ++- .../gpu/gemm_universal_preshuffle.inc | 47 ++- .../gemm_universal_preshuffle/CMakeLists.txt | 18 +- ...ersal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp | 106 ++++++ ...f8_bf16_mk_wmma_mn_default_instance_p1.cpp | 33 ++ ...f8_bf16_mk_wmma_mn_default_instance_p2.cpp | 33 ++ ...f8_bf16_mk_wmma_mn_default_instance_p3.cpp | 33 ++ ...f8_bf16_mk_wmma_mn_default_instance_p4.cpp | 33 ++ ...versal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp | 106 ++++++ ..._f8_f16_mk_wmma_mn_default_instance_p1.cpp | 33 ++ ..._f8_f16_mk_wmma_mn_default_instance_p2.cpp | 33 ++ ..._f8_f16_mk_wmma_mn_default_instance_p3.cpp | 33 ++ ..._f8_f16_mk_wmma_mn_default_instance_p4.cpp | 33 ++ test/gemm_universal_preshuffle/CMakeLists.txt | 4 +- ...=> test_gemm_universal_preshuffle_fp8.cpp} | 0 20 files changed, 1229 insertions(+), 14 deletions(-) create mode 100644 example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp create mode 100644 example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp create mode 100644 example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp rename test/gemm_universal_preshuffle/{test_gemm_universal_preshuffle_xdl_fp8.cpp => test_gemm_universal_preshuffle_fp8.cpp} (100%) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2d65368d4f..aba462638e 100644 --- a/example/01_gemm/CMakeLists.txt +++ 
b/example/01_gemm/CMakeLists.txt @@ -149,3 +149,7 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) +add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle) +add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_bpreshuffle) diff --git a/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp b/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp new file mode 100644 index 0000000000..d03971e6ec --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/scheduler_enum.hpp" + +#include +#include +#include + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using ComputeTypeA = F16; +using ComputeTypeB = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; +static constexpr int KPack = 8; // int4 -> 32, fp8 -> 16, fp16 -> 8 +// clang-format off +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 32, 128, 128, + 8, 8, + 16, 16, + 2, 2, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, 
ComputeTypeB, PermuteA, PermuteB>; +// clang-format on + +#include "run_gemm_wmma_bpreshuffle_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp b/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp new file mode 100644 index 0000000000..8f8b380b93 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/scheduler_enum.hpp" + +#include +#include +#include + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F8; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using ComputeTypeA = F8; +using ComputeTypeB = F8; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; +static constexpr int KPack = 16; // int4 -> 32, fp8 -> 16, fp16 
-> 8 +// clang-format off +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 32, 128, 256, + 16, 16, + 16, 16, + 2, 1, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>; +// clang-format on + +#include "run_gemm_wmma_bpreshuffle_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc b/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc new file mode 100644 index 0000000000..b1d73cfe10 --- /dev/null +++ b/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc @@ -0,0 +1,206 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor 
c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_preshuffled: " << b_k_n_preshuffled.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // do GEMM + auto device_op = DeviceOpInstance{}; + + // weight pre-shuffle + int NPerWmma = device_op.GetPreShuffleParameters(); + int KLane = ck::get_warp_size() / NPerWmma; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NPerWmma + // N, K -> N0 K0 KLane NPerWmma KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NPerWmma; + int n1 = n % NPerWmma; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NPerWmma * KLane * K0 + k0 * KPack * NPerWmma * KLane + + k1 * KPack * NPerWmma + n1 * KPack + k2; + + b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k); + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + 
if(!device_op.IsSupportedArgument(argument)) + { + std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 50, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size{3840, 4096, 4096, 4096, 4096, 4096, 1}; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp new file mode 100644 index 0000000000..87bca24448 --- /dev/null +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp @@ -0,0 +1,303 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/tuple.hpp" + +#include +#include +#include +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3_BPreshuffle + : public DeviceGemmV2BPreshuffle +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + CLayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + 
BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + true>; + + using Argument = typename GridwiseGemm::Argument; + + int GetPreShuffleParameters() override { return NPerWmma; } + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + Tuple<>, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true>; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{std::array{p_a}, + std::array{p_b}, + std::array{}, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation 
c_element_op) override + { + return std::make_unique(std::array{p_a}, + std::array{p_b}, + std::array{}, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_BPreshuffle_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x" << NPerWmma << ", " + << "WaveMap: " + << MRepeat << "x" << NRepeat << ", " + << "VmemReadVec: " + << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << "BlkGemmPipelinePrefetchStages: " + << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", " + << "Kpack: " + << GridwiseGemm::KPack; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp index d8d1776a44..1a5709854c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp @@ -3,18 +3,19 @@ #pragma once -#include -#include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#ifdef CK_USE_XDL +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#if defined(CK_USE_XDL) || defined(CK_USE_WMMA) #include "gemm_universal_preshuffle.inc" #endif +#include +#include + namespace ck { namespace tensor_operation { namespace device { @@ -51,7 +52,7 @@ struct DeviceOperationInstanceFactory< static auto GetInstances() { -#ifdef CK_USE_XDL +#if defined(CK_USE_XDL) || defined(CK_USE_WMMA) std::vector> op_ptrs; #if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) if constexpr(is_same_v && is_same_v && @@ -60,6 +61,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( op_ptrs); add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_compute_instances( @@ -90,6 +92,17 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part1( op_ptrs); +#endif +#ifdef CK_USE_WMMA + 
add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4( + op_ptrs); +#endif } } #endif @@ -100,6 +113,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( @@ -136,10 +150,21 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4( + op_ptrs); +#endif } } #endif -#endif // CK_USE_XDL +#endif // CK_USE_XDL || CK_USE_WMMA return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc index b983913953..4f61958f34 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc @@ -13,8 +13,7 @@ namespace instance { using GemmF8F8BF16InstanceVector = std::vector>>&; -using GemmF8F8F16InstanceVector = std::vector>>&; +#ifdef 
CK_USE_XDL void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( GemmF8F8BF16InstanceVector& instances); @@ -61,7 +60,32 @@ void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp GemmF8F8BF16InstanceVector& instances); #endif + +#ifdef CK_USE_WMMA + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4( + GemmF8F8BF16InstanceVector& instances); + +#endif + +#endif + #if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) + +using GemmF8F8F16InstanceVector = std::vector>>&; + +#ifdef CK_USE_XDL + void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( @@ -99,6 +123,25 @@ void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_ GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6( GemmF8F8F16InstanceVector& instances); + +#endif + +#ifdef CK_USE_WMMA + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1( + GemmF8F8F16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2( + GemmF8F8F16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3( + GemmF8F8F16InstanceVector& instances); + +void 
add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4( + GemmF8F8F16InstanceVector& instances); + +#endif + #endif } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt index a022b746ac..c8fc544c83 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_INSTANCES) # F8_F8_BF16 @@ -21,6 +21,10 @@ device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshu device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp ) # F8_F8_F16 @@ -43,6 +47,10 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES 
device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp ) # F8_F8_F16 @@ -64,6 +72,10 @@ set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/devic set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS 
";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") # F8_F8_BF16 set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -81,5 +93,9 @@ set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/devi set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS 
";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_universal_preshuffle_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp new file mode 100644 index 0000000000..dd56980f0a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" + +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto v1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p1 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 8, 1, 32>, S<4, 4, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + 
DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p2 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 
256, 128, 256, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p3 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, 
S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p4 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| 
BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, 
PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 1, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp new file mode 100644 index 0000000000..e7e43db376 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p1{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp new file mode 100644 index 0000000000..240548279c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p2{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp new file mode 100644 index 0000000000..af936b3924 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p3{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp new file mode 100644 index 0000000000..019f27e01a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p4{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp new file mode 100644 index 0000000000..b2b823d3bd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" + +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto v1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p1 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 8, 1, 32>, S<4, 4, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + 
DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p2 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 
128, 256, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p3 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, 
S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p4 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| 
BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, 
PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 1, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp new file mode 100644 index 0000000000..c1dc5f263b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p1{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp new file mode 100644 index 0000000000..148edd3035 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p2{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp new file mode 100644 index 0000000000..d9918d967c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p3{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp new file mode 100644 index 0000000000..4635cdaec0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p4{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/gemm_universal_preshuffle/CMakeLists.txt b/test/gemm_universal_preshuffle/CMakeLists.txt index 1abc4391bb..fd13826a4c 100644 --- a/test/gemm_universal_preshuffle/CMakeLists.txt +++ b/test/gemm_universal_preshuffle/CMakeLists.txt @@ -2,8 +2,8 @@ # SPDX-License-Identifier: MIT if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") - add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) + add_gtest_executable(test_gemm_universal_preshuffle_fp8 test_gemm_universal_preshuffle_fp8.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) + target_link_libraries(test_gemm_universal_preshuffle_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) endif() endif() diff --git a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_fp8.cpp similarity index 100% rename from test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp rename to test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_fp8.cpp From e1f2a440960b9025ecedd00ff7ac7553c4de9e10 Mon Sep 17 00:00:00 2001 From: Michal Kulikowski Date: Wed, 14 Jan 2026 17:24:07 +0100 Subject: [PATCH 12/99] [CK][Examples] Fixing stride issues in ck examples 14/65/68/69 by workaround - Bypassing hostTensor 
validation -Fixing args num in ck examples 68/69 Signed-off-by: Michal Kulikowski --- .../gemm_wmma_quantization_int8.cpp | 13 +++++++------ .../gemm_add_add_wmma_fp16.cpp | 9 +++++---- .../run_gemm_multiply_multiply_wp_example.inc | 6 ++++-- example/68_gemm_add/common.hpp | 2 +- example/68_gemm_add/run_gemm_add_example_wmma.inc | 5 +++-- example/68_gemm_add/run_gemm_add_example_xdl.inc | 5 +++-- example/69_gemm_add_relu/common.hpp | 2 +- .../run_gemm_add_relu_example_wmma.inc | 5 +++-- .../run_gemm_add_relu_example_xdl.inc | 5 +++-- 9 files changed, 30 insertions(+), 22 deletions(-) diff --git a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp index cc5e3616ff..7437d0be9d 100644 --- a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp @@ -27,10 +27,11 @@ using ::ck::Tensor; template using S = ck::Sequence; -using I8 = int8_t; -using I32 = int32_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = PassThrough; @@ -125,11 +126,11 @@ int main(int /* argc */, char* /* argv */[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp index 24c58bb69a..1e3d946bad 100644 --- 
a/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp @@ -31,8 +31,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F16; using B0DataType = F16; @@ -139,11 +140,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc b/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc index 2de3222380..10dce7fe64 100644 --- a/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc +++ b/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc @@ -5,6 +5,8 @@ int run_gemm_example(int argc, char* argv[]) { + using Bypass = ck::tensor_layout::BypassLayoutVerification; + bool do_verification = true; int init_method = 1; bool time_kernel = false; @@ -64,11 +66,11 @@ int run_gemm_example(int argc, char* argv[]) if(std::is_same::value) { - return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return ck::HostTensorDescriptor({row, col}, {1_uz, stride}); + return ck::HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/68_gemm_add/common.hpp b/example/68_gemm_add/common.hpp index 362dc2fff2..12d4b381b2 100644 --- a/example/68_gemm_add/common.hpp +++ 
b/example/68_gemm_add/common.hpp @@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } - else if(argc == 13) + else if(argc == 11) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/68_gemm_add/run_gemm_add_example_wmma.inc b/example/68_gemm_add/run_gemm_add_example_wmma.inc index ba15d03e07..0f2cc08edf 100644 --- a/example/68_gemm_add/run_gemm_add_example_wmma.inc +++ b/example/68_gemm_add/run_gemm_add_example_wmma.inc @@ -6,6 +6,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/68_gemm_add/run_gemm_add_example_xdl.inc b/example/68_gemm_add/run_gemm_add_example_xdl.inc index da22230a4e..186423d32f 100644 --- a/example/68_gemm_add/run_gemm_add_example_xdl.inc +++ b/example/68_gemm_add/run_gemm_add_example_xdl.inc @@ -6,6 +6,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config [](std::size_t row, std::size_t col, 
std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/69_gemm_add_relu/common.hpp b/example/69_gemm_add_relu/common.hpp index e54c5317ae..de84d69a5e 100644 --- a/example/69_gemm_add_relu/common.hpp +++ b/example/69_gemm_add_relu/common.hpp @@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } - else if(argc == 13) + else if(argc == 11) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc b/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc index 8deac6dec8..c3cfd00ab3 100644 --- a/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc @@ -6,6 +6,7 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc 
b/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc index df7474bab5..cca85aa11c 100644 --- a/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc @@ -6,6 +6,7 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; From f57395689b92ca1f644e6e549e763f6c293ced22 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 07:49:06 -0800 Subject: [PATCH 13/99] Bump rocm-docs-core[api_reference] from 1.31.1 to 1.31.2 in /docs/sphinx (#3577) Bumps [rocm-docs-core[api_reference]](https://github.com/ROCm/rocm-docs-core) from 1.31.1 to 1.31.2. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.31.1...v1.31.2) --- updated-dependencies: - dependency-name: rocm-docs-core[api_reference] dependency-version: 1.31.2 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index b1ab09e6f7..b37c5c5652 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.31.1 +rocm-docs-core[api_reference]==1.31.2 sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 099e9e439f..7f0d71cc4b 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.31.1 +rocm-docs-core[api-reference]==1.31.2 # via -r requirements.in rpds-py==0.24.0 # via From 086a1f8861ef8c81db854e7f2749458b69121617 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 15 Jan 2026 08:30:23 -0800 Subject: [PATCH 14/99] Add LLM-agnostic Docker and build analysis tools (#3576) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit introduces utility tools for building, testing, and analyzing Composable Kernel. The tools are designed to be LLM-agnostic and can be used with any AI assistant or directly from the command line. Tools Added: ============ 1. ck-docker - Docker container management - Start/stop ROCm-enabled containers - Build targets with CMake + Ninja - Run tests with gtest filters - Auto-detect GPU targets (gfx950, gfx942, etc.) - Per-user, per-branch container naming to avoid conflicts 2. 
ck-build-analysis - Build time profiling - Uses Clang's -ftime-trace for compilation analysis - Aggregates statistics across multiple trace files - Identifies template instantiation bottlenecks - Generates detailed Markdown reports with: * Compilation phase breakdown * Top expensive instantiations * Template family analysis * Data-driven optimization recommendations - Configurable granularity (1µs to 500µs) - PEP 723 compliant Python script with auto-dependency management via uv Key Features: ============= - LLM-agnostic design (works with any AI assistant) - Zero-configuration setup with automatic dependency installation - Comprehensive documentation in script/tools/README*.md - Security hardening (input validation, no command injection) - Multi-file trace aggregation for accurate build analysis - Jinja2-based report generation for customizable output Implementation: =============== - script/tools/ck-docker - Main Docker orchestration script - script/tools/ck-build-analysis - Build analysis orchestration - script/tools/common.sh - Shared utilities (container mgmt, GPU detection) - script/tools/analyze_build_trace.py - PEP 723 compliant Python analyzer - script/tools/templates/ - Jinja2 templates for report generation - script/tools/README*.md - Comprehensive documentation Directory Structure: ==================== script/tools/ ├── README.md # Main overview ├── README_ck-docker.md # ck-docker documentation ├── README_ck-build-analysis.md # ck-build-analysis documentation ├── ck-docker # Docker orchestration script ├── ck-build-analysis # Build analysis orchestration ├── common.sh # Shared utilities ├── analyze_build_trace.py # Python analyzer (PEP 723) └── templates/ └── build_analysis_report.md.jinja # Report template The tools follow Unix philosophy: do one thing well, compose easily, and work from both CLI and programmatic contexts. 
--- script/tools/README.md | 78 ++++ script/tools/README_ck-build-analysis.md | 168 +++++++++ script/tools/README_ck-docker.md | 80 ++++ script/tools/analyze_build_trace.py | 347 ++++++++++++++++++ script/tools/ck-build-analysis | 237 ++++++++++++ script/tools/ck-docker | 294 +++++++++++++++ script/tools/common.sh | 97 +++++ .../templates/build_analysis_report.md.jinja | 125 +++++++ 8 files changed, 1426 insertions(+) create mode 100644 script/tools/README.md create mode 100644 script/tools/README_ck-build-analysis.md create mode 100644 script/tools/README_ck-docker.md create mode 100755 script/tools/analyze_build_trace.py create mode 100755 script/tools/ck-build-analysis create mode 100755 script/tools/ck-docker create mode 100644 script/tools/common.sh create mode 100644 script/tools/templates/build_analysis_report.md.jinja diff --git a/script/tools/README.md b/script/tools/README.md new file mode 100644 index 0000000000..e5bf91cedc --- /dev/null +++ b/script/tools/README.md @@ -0,0 +1,78 @@ +# Composable Kernel Tools + +This directory contains utility tools for building, testing, and analyzing Composable Kernel. + +These tools are designed to be LLM-agnostic and can be used with any AI assistant or directly from the command line. + +## Available Tools + +### ck-docker + +Build and test composable_kernel in Docker with ROCm support. + +See [README_ck-docker.md](README_ck-docker.md) for details. + +**Quick start:** +```bash +# Add to PATH +export PATH="$PATH:$PWD/script/tools" + +# Start container and build +ck-docker start +ck-docker build test_amdgcn_mma +ck-docker test test_amdgcn_mma +``` + +### ck-build-analysis + +Analyze Composable Kernel build times using Clang's -ftime-trace profiler. + +See [README_ck-build-analysis.md](README_ck-build-analysis.md) for details. 
+ +**Quick start:** +```bash +# Add to PATH +export PATH="$PATH:$PWD/script/tools" + +# Analyze build time +ck-build-analysis example_convnd_fwd_xdl_fp8 +``` + +## LLM Assistant Integration + +These tools can be used as-is with any LLM assistant by providing the tool documentation to the assistant. The assistant can then invoke these tools on your behalf. + +For example, you can ask: +- "Start the docker container" +- "Build and test test_amdgcn_mma" +- "Analyze build time for example_convnd_fwd_xdl_fp8" + +The assistant will translate your natural language request into the appropriate tool invocation. + +## Dependencies + +- **ck-docker**: Requires Docker and ROCm-capable GPU (for running tests) +- **ck-build-analysis**: Requires Docker, automatically installs Python dependencies (jinja2) via `uv` + +## Directory Structure + +``` +script/tools/ +├── README.md # This file +├── README_ck-docker.md # Documentation for ck-docker +├── README_ck-build-analysis.md # Documentation for ck-build-analysis +├── ck-docker # Docker container management tool +├── ck-build-analysis # Build time analysis tool +├── common.sh # Shared utilities for bash scripts +├── analyze_build_trace.py # Python script for trace analysis (PEP 723 compliant) +└── templates/ + └── build_analysis_report.md.jinja # Jinja2 template for analysis reports +``` + +## Contributing + +When adding new tools to this directory: +1. Keep them LLM-agnostic (avoid hardcoding references to specific AI assistants) +2. Provide clear command-line usage documentation +3. Include examples for both CLI and LLM assistant usage +4. Follow the existing naming convention and structure diff --git a/script/tools/README_ck-build-analysis.md b/script/tools/README_ck-build-analysis.md new file mode 100644 index 0000000000..d52e4eb2c7 --- /dev/null +++ b/script/tools/README_ck-build-analysis.md @@ -0,0 +1,168 @@ +# ck-build-analysis + +Analyze Composable Kernel build times using Clang's -ftime-trace profiler. 
+ +## Terminal Usage + +Direct command-line usage: + +```bash +# From composable_kernel directory +script/tools/ck-build-analysis example_convnd_fwd_xdl_fp8 +script/tools/ck-build-analysis example_convnd_fwd_xdl_fp8 --granularity=1 +script/tools/ck-build-analysis example_convnd_fwd_xdl_fp8 --granularity=1 --output=my_report.md + +# Or add to PATH +export PATH="$PATH:$PWD/script/tools" +ck-build-analysis example_convnd_fwd_xdl_fp8 +``` + +## LLM Assistant Integration + +If using an LLM assistant, you can ask in natural language: +- "Analyze build time for example_convnd_fwd_xdl_fp8" +- "Profile the compilation of test_amdgcn_mma with 1us granularity" +- "Generate a build time report for example_gemm_xdl" + +## Commands + +``` +ck-build-analysis [options] + +Options: + --granularity=N Time trace granularity in microseconds (default: 1) + --output=FILE Output report filename (default: build_time_analysis_report.md) + --name=NAME Docker container name (default: from CK_CONTAINER_NAME or auto-generated) + --no-reconfigure Skip CMake reconfiguration if build exists + --help Show this help message +``` + +## What It Does + +1. **Configures CMake** with `-ftime-trace` and custom granularity +2. **Builds the target** using Ninja in Docker +3. **Analyzes the trace** JSON file for template instantiation patterns +4. 
**Generates a report** with: + - Compilation phase breakdown + - Top expensive individual instantiations + - Template families ranked by total time and count + - Key insights and optimization recommendations + - Complete statistics + +## Configuration + +- **Container**: Uses ck-docker container (auto-starts if needed) +- **Granularity**: Default 1us (100% template coverage, best balance) +- **Output**: Markdown report in project root + +## Environment + +```bash +export CK_CONTAINER_NAME=my_build # Override container name +export CK_BUILD_ANALYSIS_GRANULARITY=1 # Default granularity in microseconds +``` + +## Examples + +```bash +# Complete template analysis with default granularity (1us - recommended) +ck-build-analysis example_convnd_fwd_xdl_fp8 + +# Quick daily check (10us granularity, captures most expensive templates) +ck-build-analysis example_convnd_fwd_xdl_fp8 --granularity=10 + +# Maximum detail (0us granularity, includes LLVM internals) +ck-build-analysis example_convnd_fwd_xdl_fp8 --granularity=0 + +# High-level overview (500us granularity, major bottlenecks only) +ck-build-analysis example_convnd_fwd_xdl_fp8 --granularity=500 + +# Custom output filename +ck-build-analysis example_convnd_fwd_xdl_fp8 --output=fp8_conv_analysis.md + +# Analyze test target +ck-build-analysis test_amdgcn_mma + +# Use existing build (skip reconfigure) +ck-build-analysis example_convnd_fwd_xdl_fp8 --no-reconfigure +``` + +## Output + +The report includes: +- **Executive Summary**: Total time, events, instantiations, unique templates +- **Compilation Phases**: InstantiateFunction, Frontend, Backend, Optimizer, etc. 
+- **Top 30 Individual Instantiations**: Most expensive single templates +- **Template Families**: Grouped by total time and instantiation count +- **Key Insights**: What's slow and why +- **Optimization Recommendations**: Short, medium, and long-term strategies +- **Detailed Statistics**: Averages, medians, distributions + +## Granularity Trade-offs + +| Granularity | Template Coverage | Use Case | +|-------------|-------------------|----------| +| **0us** | All templates + sub-us compiler internals | LLVM internals debugging, very large files, higher overhead | +| **1us (default)** | **All templates** | **Default: Complete template analysis with low overhead** | +| **10us** | Most expensive templates | Daily quick checks, smaller files, minimal overhead | +| **50-100us** | Top bottlenecks | Balanced detail/size, suitable for CI/CD | +| **500us** | High-level phases only | Not recommended for template analysis | + +**Recommended default**: 1us captures all template instantiations with minimal overhead + +## Notes + +- **0us and 1us capture all templates** - 0us adds sub-microsecond compiler internals +- **1us is the sweet spot**: complete template coverage, filters noise, low overhead +- **10us is practical** for daily use: captures most expensive templates, smaller files +- **500us loses most template instantiation data** - only use for high-level phase breakdown +- Finer granularity = more events = larger files + higher build time overhead +- For template-heavy C++ codebases like CK: **use 1us for analysis, 10us for daily checks** + +## Implementation Details + +### PEP 723 Compliance with Automatic Dependency Management + +The analysis script (`analyze_build_trace.py`) is PEP 723 compliant with inline dependency metadata: + +```python +# /// script +# requires-python = ">=3.8" +# dependencies = [ +# "jinja2>=3.0.0", +# ] +# /// +``` + +**The tool automatically installs and uses `uv`**, which provides: +- ✅ Zero-configuration dependency management +- ✅ Automatic 
installation of jinja2 from PEP 723 metadata +- ✅ Isolated dependency environment (no system pollution) +- ✅ Fast caching for subsequent runs + +**No manual setup required!** The first time you run the tool, it will: +1. Detect if `uv` is installed in the container +2. If not, automatically install it via Ubuntu packages (pipx install uv) +3. Use `uv run` to execute the analysis with auto-managed dependencies + +On subsequent runs, `uv` will already be available and dependencies will be cached. + +Installation is done through Ubuntu's package manager for security and reliability. + +### Components + +- **ck-build-analysis** - Main bash script that orchestrates Docker, CMake, and analysis +- **analyze_build_trace.py** - PEP 723 compliant Python script for trace analysis +- **templates/build_analysis_report.md.jinja** - Jinja2 template for report generation + +### Standalone Usage + +The Python script can also be run independently: + +```bash +# With uv (recommended - auto-installs dependencies from PEP 723 metadata) +uv run script/tools/analyze_build_trace.py trace.json report.md target 100 22 templates/ + +# With pipx (alternative - also auto-installs dependencies) +pipx run script/tools/analyze_build_trace.py trace.json report.md target 100 22 templates/ +``` diff --git a/script/tools/README_ck-docker.md b/script/tools/README_ck-docker.md new file mode 100644 index 0000000000..c432c1dba9 --- /dev/null +++ b/script/tools/README_ck-docker.md @@ -0,0 +1,80 @@ +# ck-docker + +Build and test composable_kernel in Docker with ROCm support. 
+ +## Terminal Usage + +Direct command-line usage: + +```bash +# From composable_kernel directory +script/tools/ck-docker start +script/tools/ck-docker build test_amdgcn_mma +script/tools/ck-docker test test_amdgcn_mma --gtest_filter=*Fp16* +script/tools/ck-docker status +script/tools/ck-docker shell + +# Or add to PATH +export PATH="$PATH:$PWD/script/tools" +ck-docker start +``` + +## LLM Assistant Integration + +If using an LLM assistant, you can ask in natural language: +- "Start the docker container" +- "Build test_amdgcn_mma" +- "Run test_amdgcn_mma with filter *Fp16*" +- "Check container status" +- "Open a shell in the container" + +## Commands + +``` +ck-docker start [name] Start Docker container +ck-docker build [target] [--reconfigure] Build target (optionally reconfigure CMake) +ck-docker test [options] Run test +ck-docker shell [name] Interactive shell +ck-docker status [name] Check status +ck-docker stop [name] Stop container +``` + +## Configuration + +- **Image**: rocm/composable_kernel:ck_ub24.04_rocm7.0.1 +- **GPU**: Auto-detected via rocminfo (fallback: gfx950) +- **Compiler**: /opt/rocm/llvm/bin/clang++ +- **Build**: Ninja + CMake (Release) +- **Mount**: Current directory → /workspace +- **Container Name**: Auto-generated as `ck__` to avoid clashes + +## Environment + +```bash +export CK_CONTAINER_NAME=my_build # Override default container name +export CK_DOCKER_IMAGE=rocm/composable_kernel:ck_ub24.04_rocm7.0.1 # Override Docker image +export GPU_TARGET=gfx942 # Override GPU target detection +``` + +## Examples + +```bash +# Start container +ck-docker start + +# Build and run test +ck-docker build test_amdgcn_mma +ck-docker test test_amdgcn_mma + +# Force clean CMake reconfiguration and build +ck-docker build --reconfigure test_amdgcn_mma + +# Custom container +ck-docker start my_build +ck-docker build test_amdgcn_mma --name my_build +ck-docker test test_amdgcn_mma --name my_build + +# Debug +ck-docker shell +ck-docker status +``` diff --git 
a/script/tools/analyze_build_trace.py b/script/tools/analyze_build_trace.py new file mode 100755 index 0000000000..3597132f32 --- /dev/null +++ b/script/tools/analyze_build_trace.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# /// script +# requires-python = ">=3.8" +# dependencies = [ +# "jinja2>=3.0.0", +# ] +# /// +""" +Build Time Analysis Tool for Composable Kernel + +Analyzes Clang -ftime-trace output to identify template instantiation +bottlenecks and generate comprehensive build time reports. +""" + +import json +import os +import re +import sys +from collections import defaultdict +from datetime import datetime + +try: + from jinja2 import Environment, FileSystemLoader +except ImportError: + print("Error: jinja2 is required but not installed.", file=sys.stderr) + print("Install with: apt-get install python3-jinja2", file=sys.stderr) + print("Or with pip: pip install jinja2", file=sys.stderr) + sys.exit(1) + + +def parse_arguments(): + """Parse command-line arguments.""" + if len(sys.argv) < 7: + print( + "Usage: analyze_build_trace.py " + ) + print( + " trace_files_or_dir: Comma-separated list of trace files OR directory containing .json files" + ) + sys.exit(1) + + return { + "trace_input": sys.argv[1], + "output_file": sys.argv[2], + "target": sys.argv[3], + "granularity": sys.argv[4], + "build_time": sys.argv[5], + "template_dir": sys.argv[6], + } + + +def find_trace_files(trace_input): + """Find all trace files from input (file list, single file, or directory).""" + trace_files = [] + + # Check if it's a directory + if os.path.isdir(trace_input): + print(f"Scanning directory: {trace_input}") + for root, dirs, files in os.walk(trace_input): + for file in files: + # Include .cpp.json and .hip.json, exclude compile_commands.json and CMake files + if file.endswith((".cpp.json", ".hip.json")) and "CMakeFiles" in root: + trace_files.append(os.path.join(root, 
file)) + trace_files.sort() + # Check if it's a comma-separated list + elif "," in trace_input: + trace_files = [f.strip() for f in trace_input.split(",")] + # Single file + else: + trace_files = [trace_input] + + # Filter out non-existent files + valid_files = [f for f in trace_files if os.path.isfile(f)] + + if not valid_files: + print(f"Error: No valid trace files found in: {trace_input}", file=sys.stderr) + sys.exit(1) + + print(f"Found {len(valid_files)} trace file(s)") + return valid_files + + +def load_trace_data(trace_files): + """Load and parse multiple trace JSON files.""" + all_data = [] + + for trace_file in trace_files: + print(f" Loading: {trace_file}") + try: + with open(trace_file, "r") as f: + data = json.load(f) + # Get file basename for tracking + file_name = os.path.basename(trace_file) + all_data.append({"file": file_name, "path": trace_file, "data": data}) + except Exception as e: + print(f" Warning: Failed to load {trace_file}: {e}", file=sys.stderr) + + return all_data + + +def process_events(all_trace_data): + """Process trace events from multiple files and extract statistics.""" + print("Processing events from all files...") + + template_stats = defaultdict(lambda: {"count": 0, "total_dur": 0}) + phase_stats = defaultdict(int) + top_individual = [] + file_stats = [] + total_events = 0 + + for trace_info in all_trace_data: + file_name = trace_info["file"] + data = trace_info["data"] + events = data.get("traceEvents", []) + + file_template_time = 0 + file_event_count = len(events) + total_events += file_event_count + + print(f" Processing {file_name}: {file_event_count:,} events") + + for event in events: + name = event.get("name", "") + dur = int(event.get("dur", 0)) # Keep as integer microseconds + + if name and dur > 0: + phase_stats[name] += dur + + if name in ["InstantiateFunction", "InstantiateClass"]: + detail = event.get("args", {}).get("detail", "") + top_individual.append( + {"detail": detail, "dur": dur, "type": name, "file": 
file_name} + ) + + file_template_time += dur + + # Extract template name (everything before '<' or '(') + match = re.match(r"^([^<(]+)", detail) + if match: + template_name = match.group(1).strip() + # Normalize template names + template_name = re.sub(r"^ck::", "", template_name) + template_name = re.sub(r"^std::", "std::", template_name) + + template_stats[template_name]["count"] += 1 + template_stats[template_name]["total_dur"] += dur + + file_stats.append( + { + "name": file_name, + "events": file_event_count, + "template_time": file_template_time, + } + ) + + return template_stats, phase_stats, top_individual, file_stats, total_events + + +def prepare_template_data(template_stats, phase_stats, top_individual, file_stats): + """Prepare and calculate derived statistics for template rendering.""" + print("Sorting data...") + + # Sort data + sorted_phases = sorted(phase_stats.items(), key=lambda x: x[1], reverse=True) + top_individual.sort(key=lambda x: x["dur"], reverse=True) + file_stats.sort(key=lambda x: x["template_time"], reverse=True) + + # Calculate totals + total_template_time = sum(s["total_dur"] for s in template_stats.values()) + total_trace_time = sum(phase_stats.values()) + total_inst = sum(s["count"] for s in template_stats.values()) + + # Prepare templates by time with calculated fields + templates_by_time = [] + for name, stats in sorted( + template_stats.items(), key=lambda x: x[1]["total_dur"], reverse=True + ): + templates_by_time.append( + ( + name, + { + "count": stats["count"], + "total_dur": stats["total_dur"], + "avg": stats["total_dur"] // stats["count"] + if stats["count"] > 0 + else 0, + "pct": 100 * stats["total_dur"] / total_template_time + if total_template_time > 0 + else 0, + }, + ) + ) + + # Prepare templates by count + templates_by_count = [] + for name, stats in sorted( + template_stats.items(), key=lambda x: x[1]["count"], reverse=True + ): + templates_by_count.append( + ( + name, + { + "count": stats["count"], + "total_dur": 
stats["total_dur"], + "avg": stats["total_dur"] // stats["count"] + if stats["count"] > 0 + else 0, + }, + ) + ) + + # Add friendly type names to individual instantiations + for inst in top_individual: + inst["inst_type"] = "Func" if inst["type"] == "InstantiateFunction" else "Class" + + # Calculate additional metrics + median_count = 0 + if len(template_stats) > 0: + median_count = sorted([s["count"] for s in template_stats.values()])[ + len(template_stats) // 2 + ] + + top10_pct = 0 + if len(templates_by_time) >= 10: + top10_pct = ( + 100 + * sum(s[1]["total_dur"] for s in templates_by_time[:10]) + / total_template_time + ) + + return { + "sorted_phases": sorted_phases, + "top_individual": top_individual, + "templates_by_time": templates_by_time, + "templates_by_count": templates_by_count, + "total_template_time": total_template_time, + "total_trace_time": total_trace_time, + "total_inst": total_inst, + "median_count": median_count, + "top10_pct": top10_pct, + "unique_families": len(template_stats), + "file_stats": file_stats, + } + + +def setup_jinja_environment(template_dir): + """Set up Jinja2 environment with custom filters.""" + env = Environment(loader=FileSystemLoader(template_dir)) + + def format_number(value): + """Format number with thousand separators.""" + return f"{value:,}" + + def truncate(value, length): + """Truncate string to length with ellipsis.""" + if len(value) > length: + return value[: length - 3] + "..." 
+ return value + + def pad(value, length): + """Pad string to specified length.""" + return f"{value:<{length}}" + + def us_to_ms(value): + """Convert microseconds to milliseconds.""" + return value / 1000.0 + + def us_to_s(value): + """Convert microseconds to seconds.""" + return value / 1000000.0 + + env.filters["format_number"] = format_number + env.filters["truncate"] = truncate + env.filters["pad"] = pad + env.filters["us_to_ms"] = us_to_ms + env.filters["us_to_s"] = us_to_s + + return env + + +def generate_report(env, data, args, total_events, num_files): + """Generate the final report using Jinja2 template.""" + print("Rendering report with Jinja2...") + + template = env.get_template("build_analysis_report.md.jinja") + + report_content = template.render( + timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + target=args["target"], + granularity=args["granularity"], + build_time=args["build_time"], + total_events=total_events, + num_files=num_files, + total_instantiations=data["total_inst"], + unique_families=data["unique_families"], + total_trace_time=data["total_trace_time"], + total_template_time=data["total_template_time"], + phases=data["sorted_phases"], + top_individual=data["top_individual"], + templates_by_time=data["templates_by_time"], + templates_by_count=data["templates_by_count"], + median_count=data["median_count"], + top10_pct=data["top10_pct"], + file_stats=data["file_stats"], + ) + + return report_content + + +def main(): + """Main entry point for the analysis tool.""" + args = parse_arguments() + + # Find and load trace files + trace_files = find_trace_files(args["trace_input"]) + all_trace_data = load_trace_data(trace_files) + + # Process events from all files + template_stats, phase_stats, top_individual, file_stats, total_events = ( + process_events(all_trace_data) + ) + + # Prepare template data + data = prepare_template_data( + template_stats, phase_stats, top_individual, file_stats + ) + + # Setup Jinja2 environment + env = 
setup_jinja_environment(args["template_dir"]) + + # Generate report + report_content = generate_report(env, data, args, total_events, len(all_trace_data)) + + # Write output + with open(args["output_file"], "w") as f: + f.write(report_content) + + print(f"Report generated: {args['output_file']}") + print(f"Report size: {len(report_content):,} bytes") + print(f"Analyzed {len(all_trace_data)} file(s) with {total_events:,} total events") + + +if __name__ == "__main__": + main() diff --git a/script/tools/ck-build-analysis b/script/tools/ck-build-analysis new file mode 100755 index 0000000000..cd06a1796f --- /dev/null +++ b/script/tools/ck-build-analysis @@ -0,0 +1,237 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Build Analysis Tool - Analyze build times using -ftime-trace + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Default settings +GRANULARITY="${CK_BUILD_ANALYSIS_GRANULARITY:-1}" +OUTPUT_FILE="build_time_analysis_report.md" +RECONFIGURE=true + +# Help message +show_help() { + cat << EOF +CK Build Analysis - Analyze build times using Clang -ftime-trace + +Usage: ck-build-analysis [options] + +Arguments: + target Build target to analyze (e.g., example_convnd_fwd_xdl_fp8) + +Options: + --granularity=N Time trace granularity in microseconds (default: 1) + --output=FILE Output report filename (default: build_time_analysis_report.md) + --name=NAME Docker container name (default: ${CONTAINER_NAME}) + --no-reconfigure Skip CMake reconfiguration if build exists + --help Show this help message + +Examples: + ck-build-analysis example_convnd_fwd_xdl_fp8 + ck-build-analysis example_convnd_fwd_xdl_fp8 --granularity=10 + ck-build-analysis 
test_amdgcn_mma --granularity=1 --output=mma_test_analysis.md + +Granularity Guide: + 0 - Everything: All compiler events including sub-microsecond operations + Use for LLVM internals debugging. Large files, higher overhead. + + 1 (default) - Complete template coverage: Captures all template instantiations + Best balance - filters sub-microsecond noise, low overhead + + 10 - Daily use: Captures most expensive templates, smaller files + Good for quick checks and routine analysis + + 50-100 - Intermediate: Balanced between detail and file size + Suitable for CI/CD tracking + + 500 - High-level only: Major compilation phases, minimal detail + Not recommended for template analysis (loses most instantiations) + + Recommendation: Use 1us (default) for template analysis, 10us for quick checks. +EOF +} + +# Parse arguments +TARGET="" +while [[ $# -gt 0 ]]; do + case $1 in + --granularity=*) + GRANULARITY="${1#*=}" + shift + ;; + --output=*) + OUTPUT_FILE="${1#*=}" + shift + ;; + --name=*) + CONTAINER_NAME="${1#*=}" + shift + ;; + --no-reconfigure) + RECONFIGURE=false + shift + ;; + --help|-h) + show_help + exit 0 + ;; + -*) + echo "Unknown option: $1" + show_help + exit 1 + ;; + *) + if [ -z "$TARGET" ]; then + TARGET="$1" + else + echo "Error: Multiple targets specified" + show_help + exit 1 + fi + shift + ;; + esac +done + +if [ -z "$TARGET" ]; then + echo "Error: No target specified" + echo "" + show_help + exit 1 +fi + +# Validate OUTPUT_FILE to prevent path traversal +if [[ "$OUTPUT_FILE" =~ / ]] || [[ "$OUTPUT_FILE" =~ \.\. ]]; then + echo "Error: OUTPUT_FILE must be a simple filename (no path separators or .. 
allowed)" + echo "Invalid: $OUTPUT_FILE" + exit 1 +fi + +echo "═══════════════════════════════════════════════════════════════" +echo " CK Build Time Analysis" +echo "═══════════════════════════════════════════════════════════════" +echo "Target: $TARGET" +echo "Granularity: ${GRANULARITY}us" +echo "Container: $CONTAINER_NAME" +echo "Output: $OUTPUT_FILE" +echo "═══════════════════════════════════════════════════════════════" +echo "" + +# Ensure container is running +ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" + +# Configure CMake with -ftime-trace if needed +if [ "$RECONFIGURE" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "" + echo "Configuring CMake with -ftime-trace (granularity=${GRANULARITY}us)..." + + GPU_TARGET=$(detect_gpu_target "${CONTAINER_NAME}") + + docker exec -e GPU_TARGET="${GPU_TARGET}" -e GRANULARITY="${GRANULARITY}" "${CONTAINER_NAME}" bash -c ' + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS="${GPU_TARGET}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DCMAKE_CXX_FLAGS="-ftime-trace -ftime-trace-granularity=${GRANULARITY}" \ + -DCMAKE_HIP_FLAGS="-ftime-trace -ftime-trace-granularity=${GRANULARITY}" \ + -DBUILD_TESTING=ON 2>&1 | tail -20 + ' + echo "CMake configuration complete" +fi + +# Build the target +echo "" +echo "Building target: $TARGET" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +BUILD_START=$(date +%s) +docker exec -e TARGET="${TARGET}" "${CONTAINER_NAME}" bash -c 'cd /workspace/build && time ninja "${TARGET}" 2>&1' +BUILD_END=$(date +%s) +BUILD_TIME=$((BUILD_END - BUILD_START)) + +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "Build completed in ${BUILD_TIME} seconds" + +# Find all trace JSON files for the target +echo "" +echo "Locating trace files..." 
+ +# Count trace files +TRACE_COUNT=$(docker exec -e TARGET="${TARGET}" "${CONTAINER_NAME}" bash -c ' + find /workspace/build -type f \( -name "*.cpp.json" -o -name "*.hip.json" \) 2>/dev/null | \ + grep -vF "compile_commands.json" | wc -l +') + +if [ "$TRACE_COUNT" -eq 0 ]; then + echo "Error: Could not find any trace files in /workspace/build" + echo "Expected .cpp.json or .hip.json files from -ftime-trace compilation" + exit 1 +fi + +echo "Found ${TRACE_COUNT} trace file(s) in build directory" + +# We'll pass the build directory to the Python script +BUILD_DIR="/workspace/build" + +# Generate analysis report +echo "" +echo "Generating analysis report..." + +# Copy analysis script and templates to container +docker cp "${SCRIPT_DIR}/analyze_build_trace.py" "${CONTAINER_NAME}:/tmp/analyze_build_trace.py" +docker cp "${SCRIPT_DIR}/templates" "${CONTAINER_NAME}:/tmp/ck_build_analysis_templates" + +# Check if uv is available, install if needed, and use for PEP 723 dependency management +if ! docker exec "${CONTAINER_NAME}" bash -c "command -v uv >/dev/null 2>&1 || test -x \$HOME/.local/bin/uv"; then + echo "uv not found, installing via pipx..." + docker exec "${CONTAINER_NAME}" bash -c " + # Install pipx if not available + if ! command -v pipx >/dev/null 2>&1; then + apt-get update -qq && apt-get install -y -qq pipx >/dev/null 2>&1 + fi + # Install uv via pipx + pipx install uv >/dev/null 2>&1 + " + echo "uv installed successfully" +fi + +echo "Using uv run for automatic dependency management..." 
+# Ensure uv is in PATH (handles ~/.local/bin installation) +# Pass build directory instead of single file +docker exec -e BUILD_DIR="${BUILD_DIR}" -e OUTPUT_FILE="${OUTPUT_FILE}" -e TARGET="${TARGET}" -e GRANULARITY="${GRANULARITY}" -e BUILD_TIME="${BUILD_TIME}" "${CONTAINER_NAME}" bash -c 'export PATH="$HOME/.local/bin:$PATH" && uv run --no-project /tmp/analyze_build_trace.py "${BUILD_DIR}" "/workspace/${OUTPUT_FILE}" "${TARGET}" "${GRANULARITY}" "${BUILD_TIME}" /tmp/ck_build_analysis_templates' + +# Copy report back to host +docker cp "${CONTAINER_NAME}:/workspace/${OUTPUT_FILE}" "${PROJECT_ROOT}/${OUTPUT_FILE}" + +# Cleanup +docker exec "${CONTAINER_NAME}" rm -f /tmp/analyze_build_trace.py +docker exec "${CONTAINER_NAME}" rm -rf /tmp/ck_build_analysis_templates + +echo "" +echo "═══════════════════════════════════════════════════════════════" +echo " Analysis Complete!" +echo "═══════════════════════════════════════════════════════════════" +echo "Report: ${PROJECT_ROOT}/${OUTPUT_FILE}" +echo "" +echo "Summary:" +docker exec "${CONTAINER_NAME}" bash -c "head -20 /workspace/${OUTPUT_FILE} | tail -10" +echo "" +echo "View the full report:" +echo " cat ${OUTPUT_FILE}" +echo " or open it in your editor" +echo "═══════════════════════════════════════════════════════════════" diff --git a/script/tools/ck-docker b/script/tools/ck-docker new file mode 100755 index 0000000000..82bf770011 --- /dev/null +++ b/script/tools/ck-docker @@ -0,0 +1,294 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +# CK Docker Tool - Build and test composable_kernel in Docker with ROCm support + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Docker Tool - Build and test composable_kernel in Docker + +Usage: ck-docker [options] + +Commands: + start [name] Start Docker container + build [target] [--reconfigure] Build target (optionally reconfigure CMake) + test [options] Run test + shell [name] Open shell in container + status [name] Check container status + stop [name] Stop and remove container + +Examples: + ck-docker start + ck-docker build test_amdgcn_mma + ck-docker build --reconfigure test_amdgcn_mma + ck-docker test test_amdgcn_mma --gtest_filter=*Fp16* + ck-docker shell + +Environment: + CK_CONTAINER_NAME - Override default container name (default: ck__) + CK_DOCKER_IMAGE - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1) + GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) +EOF +} + +# Start container +cmd_start() { + local name="${1:-${CONTAINER_NAME}}" + local docker_image=$(get_docker_image) + + # Check if container exists and is running + if container_exists "${name}"; then + if container_is_running "${name}"; then + echo "Container '${name}' is already running" + return 0 + else + echo "Starting existing container '${name}'..." + docker start "${name}" + echo "Container started" + return 0 + fi + fi + + echo "Creating new Docker container '${name}'..." 
+ docker run -d \ + --name "${name}" \ + --device=/dev/kfd --device=/dev/dri \ + --security-opt seccomp=unconfined \ + --group-add video \ + -v "${PROJECT_ROOT}":/workspace \ + -w /workspace \ + "${docker_image}" \ + tail -f /dev/null + + echo "Container '${name}' started successfully" + docker exec "${name}" bash -c "echo 'Working directory:' && pwd" +} + +# Build target +cmd_build() { + local target="" + local name="${CONTAINER_NAME}" + local reconfigure=false + + while [[ $# -gt 0 ]]; do + case $1 in + --name) + name="$2" + shift 2 + ;; + --reconfigure) + reconfigure=true + shift + ;; + *) + target="$1" + shift + ;; + esac + done + + # Check if container is running + if ! container_is_running "${name}"; then + echo "Container '${name}' not running. Starting..." + cmd_start "${name}" + fi + + # Reconfigure CMake if requested or if build.ninja doesn't exist + if [ "$reconfigure" = true ] || ! docker exec "${name}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Detecting GPU target..." + local gpu_target=$(detect_gpu_target "${name}") + + if [ "$reconfigure" = true ]; then + echo "Reconfiguring CMake from scratch for GPU target: ${gpu_target}" + else + echo "Configuring build with CMake for GPU target: ${gpu_target}" + fi + + docker exec "${name}" bash -c " + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS=${gpu_target} \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DBUILD_TESTING=ON 2>&1 | tail -30 + " + fi + + if [ -z "$target" ]; then + echo "Building all configured targets..." 
+ else + echo "Building target: ${target}" + fi + + docker exec "${name}" bash -c " + cd /workspace/build || exit 1 + ninja ${target} 2>&1 + " + + echo "Build complete" +} + +# Run test +cmd_test() { + local test_name="" + local name="${CONTAINER_NAME}" + local -a test_options=() + + while [[ $# -gt 0 ]]; do + case $1 in + --name) + name="$2" + shift 2 + ;; + --gtest_*|--help) + test_options+=("$1") + shift + ;; + *) + if [ -z "$test_name" ]; then + test_name="$1" + else + test_options+=("$1") + fi + shift + ;; + esac + done + + if [ -z "$test_name" ]; then + echo "Error: test_name required" + echo "Usage: ck-docker test [--name container_name] [gtest_options]" + return 1 + fi + + # Check if container is running + if ! container_is_running "${name}"; then + echo "Error: Container '${name}' not running" + echo "Start it with: ck-docker start --name ${name}" + return 1 + fi + + if ! docker exec "${name}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then + echo "Test executable not found. Building ${test_name}..." + cmd_build "${test_name}" --name "${name}" + fi + + echo "Running: ${test_name} ${test_options[*]}" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + # Build the command with proper quoting + local cmd="cd /workspace/build && ./bin/${test_name}" + for opt in "${test_options[@]}"; do + cmd="${cmd} $(printf '%q' "$opt")" + done + docker exec "${name}" bash -c "${cmd}" +} + +# Shell +cmd_shell() { + local name="${1:-${CONTAINER_NAME}}" + + # Check if container is running + if ! container_is_running "${name}"; then + echo "Container '${name}' not running. Starting..." + cmd_start "${name}" + fi + + echo "Opening shell in '${name}' (type 'exit' to leave)..." 
+ docker exec -it "${name}" bash +} + +# Status +cmd_status() { + local name="${1:-}" + local docker_image=$(get_docker_image) + + if [ -z "$name" ]; then + echo "Composable Kernel Docker Containers:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker ps -a --filter "ancestor=${docker_image}" \ + --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" || echo "No containers found" + else + # Check container status + if container_is_running "${name}"; then + echo "Container '${name}' is RUNNING" + docker ps --filter "name=^${name}$" --format "table {{.Names}}\t{{.Status}}\t{{.Image}}" + echo "" + echo "GPU Information:" + docker exec "${name}" bash -c "rocm-smi --showproductname 2>/dev/null | head -10 || echo 'No GPU detected'" + elif container_exists "${name}"; then + echo "Container '${name}' exists but is STOPPED" + echo "Start with: ck-docker start ${name}" + else + echo "Container '${name}' does NOT exist" + echo "Create with: ck-docker start ${name}" + fi + fi +} + +# Stop +cmd_stop() { + local name="${1:-${CONTAINER_NAME}}" + + # Check if container exists + if container_exists "${name}"; then + echo "Stopping and removing container '${name}'..." 
+ docker stop "${name}" 2>/dev/null || true + docker rm "${name}" 2>/dev/null || true + echo "Container stopped and removed" + else + echo "Container '${name}' does not exist" + fi +} + +# Main command dispatcher +case "${1:-}" in + start) + shift + cmd_start "$@" + ;; + build) + shift + cmd_build "$@" + ;; + test) + shift + cmd_test "$@" + ;; + shell) + shift + cmd_shell "$@" + ;; + status) + shift + cmd_status "$@" + ;; + stop) + shift + cmd_stop "$@" + ;; + help|--help|-h) + show_help + ;; + *) + echo "Unknown command: ${1:-}" + echo "" + show_help + exit 1 + ;; +esac diff --git a/script/tools/common.sh b/script/tools/common.sh new file mode 100644 index 0000000000..6683572c0f --- /dev/null +++ b/script/tools/common.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Common utilities for CK Docker tools +# Shared configuration and helper functions + +# Find project root (where .git directory is) +get_project_root() { + local script_dir="$1" + cd "${script_dir}/../.." 
&& pwd +} + +# Detect git branch and sanitize for Docker naming +get_sanitized_branch() { + local project_root="$1" + local branch + + branch=$(cd "${project_root}" && git rev-parse --abbrev-ref HEAD 2>/dev/null | tr '/' '_' | tr -cd 'a-zA-Z0-9_-' || echo "") + branch=${branch:-unknown} + + # Handle detached HEAD state + if [ "${branch}" = "HEAD" ]; then + branch="detached" + fi + + echo "${branch}" +} + +# Get username with fallback +get_username() { + echo "${USER:-$(whoami 2>/dev/null || echo "user")}" +} + +# Generate default container name: ck__ +get_default_container_name() { + local project_root="$1" + local user_name + local git_branch + + user_name=$(get_username) + git_branch=$(get_sanitized_branch "${project_root}") + + echo "ck_${user_name}_${git_branch}" +} + +# Get container name (respects CK_CONTAINER_NAME env var) +get_container_name() { + local project_root="$1" + local default_name + + default_name=$(get_default_container_name "${project_root}") + echo "${CK_CONTAINER_NAME:-${default_name}}" +} + +# Get Docker image (respects CK_DOCKER_IMAGE env var) +get_docker_image() { + echo "${CK_DOCKER_IMAGE:-rocm/composable_kernel:ck_ub24.04_rocm7.0.1}" +} + +# Check if container exists (exact match) +container_exists() { + local name="$1" + docker ps -a --filter "name=^${name}$" --format '{{.Names}}' | grep -q "^${name}$" +} + +# Check if container is running (exact match) +container_is_running() { + local name="$1" + docker ps --filter "name=^${name}$" --format '{{.Names}}' | grep -q "^${name}$" +} + +# Detect GPU target in container +detect_gpu_target() { + local container="$1" + + # Allow override via GPU_TARGET environment variable + if [ -n "${GPU_TARGET:-}" ]; then + echo "${GPU_TARGET}" + return 0 + fi + + docker exec "${container}" bash -c " + rocminfo 2>/dev/null | grep -oP 'gfx[0-9a-z]+' | head -1 || echo 'gfx950' + " | tr -d '\r\n' +} + +# Ensure container is running, start if needed +ensure_container_running() { + local container="$1" + local 
script_dir="$2" + + if ! container_is_running "${container}"; then + echo "Container '${container}' not running. Starting with ck-docker..." + "${script_dir}/ck-docker" start "${container}" + fi +} diff --git a/script/tools/templates/build_analysis_report.md.jinja b/script/tools/templates/build_analysis_report.md.jinja new file mode 100644 index 0000000000..f91dce14a9 --- /dev/null +++ b/script/tools/templates/build_analysis_report.md.jinja @@ -0,0 +1,125 @@ +# Composable Kernel Build Time Analysis Report + +**Generated:** {{ timestamp }} +**Target:** {{ target }} +**Granularity:** {{ granularity }}µs +**Files Analyzed:** {{ num_files }} + +## Executive Summary + +- **Wall Clock Time:** {{ build_time }} seconds +- **Trace Time:** {{ total_trace_time|us_to_s|round(1) }} seconds +- **Template Instantiation Time:** {{ total_template_time|us_to_s|round(1) }} seconds ({{ (100 * total_template_time / total_trace_time)|round(1) }}% of trace) +- **Total Events Captured:** {{ total_events|format_number }} (across {{ num_files }} file{{ 's' if num_files != 1 else '' }}) +- **Total Template Instantiations:** {{ total_instantiations|format_number }} +- **Unique Template Families:** {{ unique_families }} + +{% if num_files > 1 -%} +## Per-File Analysis + +| File | Events | Template Time (ms) | % of Total | +|------|--------|-------------------|------------| +{% for file in file_stats[:20] -%} +| {{ file.name|truncate(50)|pad(50) }} | {{ "%7d"|format(file.events) }} | {{ "%17.2f"|format(file.template_time|us_to_ms) }} | {{ "%9.1f"|format(100 * file.template_time / total_template_time if total_template_time > 0 else 0) }}% | +{% endfor %} + +{% endif -%} +## Compilation Phase Breakdown + +| Phase | Time (ms) | Time (s) | % of Total | +|-------|-----------|----------|------------| +{% for phase, dur in phases[:20] -%} +| {{ phase|pad(40) }} | {{ "%9.2f"|format(dur|us_to_ms) }} | {{ "%8.2f"|format(dur|us_to_s) }} | {{ "%9.1f"|format(100 * dur / total_trace_time) }}% | +{% endfor %} 
+ +## Top 30 Most Expensive Individual Instantiations + +{% if num_files > 1 -%} +| Rank | Template | Type | Time (ms) | File | +|------|----------|------|-----------|------| +{% for inst in top_individual[:30] -%} +| {{ "%4d"|format(loop.index) }} | {{ inst.detail|truncate(50) }} | {{ inst.inst_type|pad(5) }} | {{ "%9.2f"|format(inst.dur|us_to_ms) }} | {{ inst.file|truncate(20) }} | +{% endfor -%} +{% else -%} +| Rank | Template | Type | Time (ms) | +|------|----------|------|-----------| +{% for inst in top_individual[:30] -%} +| {{ "%4d"|format(loop.index) }} | {{ inst.detail|truncate(70) }} | {{ inst.inst_type|pad(5) }} | {{ "%9.2f"|format(inst.dur|us_to_ms) }} | +{% endfor -%} +{% endif %} + +## Template Families by Total Time (Top 50) + +| Rank | Template Family | Count | Total (ms) | Avg (ms) | % of Total | +|------|-----------------|-------|------------|----------|------------| +{% for name, stats in templates_by_time[:50] -%} +| {{ "%4d"|format(loop.index) }} | {{ name|truncate(43)|pad(43) }} | {{ "%5d"|format(stats.count) }} | {{ "%10.2f"|format(stats.total_dur|us_to_ms) }} | {{ "%8.2f"|format(stats.avg|us_to_ms) }} | {{ "%9.1f"|format(stats.pct) }}% | +{% endfor %} + +## Template Families by Instantiation Count (Top 50) + +| Rank | Template Family | Count | Total (ms) | Avg (ms) | +|------|-----------------|-------|------------|----------| +{% for name, stats in templates_by_count[:50] -%} +| {{ "%4d"|format(loop.index) }} | {{ name|truncate(43)|pad(43) }} | {{ "%5d"|format(stats.count) }} | {{ "%10.2f"|format(stats.total_dur|us_to_ms) }} | {{ "%8.2f"|format(stats.avg|us_to_ms) }} | +{% endfor %} + +## Key Insights + +### 1. Template Instantiation Impact +- Template instantiation accounts for {{ (100 * total_template_time / total_trace_time)|round(1) }}% of total trace time +{% if unique_families >= 10 -%} +- Top 10 template families account for {{ top10_pct|round(1) }}% of instantiation time +{% endif %} + +### 2. 
Most Expensive Templates +{% if templates_by_time|length > 0 -%} +- **{{ templates_by_time[0][0] }}**: {{ templates_by_time[0][1].count|format_number }} instantiations, {{ (templates_by_time[0][1].total_dur|us_to_s)|round(2) }}s total +{% endif -%} +{% if templates_by_time|length > 1 -%} +- **{{ templates_by_time[1][0] }}**: {{ templates_by_time[1][1].count|format_number }} instantiations, {{ (templates_by_time[1][1].avg|us_to_ms)|round(2) }}ms average +{% endif %} + +## Optimization Recommendations + +### High-Impact Targets (by total time) +{% for name, stats in templates_by_time[:5] -%} +**{{ loop.index }}. {{ name }}** - {{ (stats.total_dur|us_to_s)|round(1) }}s total ({{ stats.pct|round(1) }}%) + - {{ stats.count|format_number }} instantiations, {{ (stats.avg|us_to_ms)|round(2) }}ms average + {% if stats.count > 100 -%} + - Strategy: Extern templates - High instantiation count suggests repeated compilation + {% elif stats.avg|us_to_ms > 50 -%} + - Strategy: Template specialization - High individual cost suggests complexity + {% else -%} + - Strategy: Explicit instantiation - Pre-instantiate common configurations + {% endif %} + +{% endfor %} +### Frequently Instantiated (optimization candidates) +{% for name, stats in templates_by_count[:5] if stats.count > 100 -%} +**{{ name }}** - {{ stats.count|format_number }} times ({{ (stats.total_dur|us_to_s)|round(2) }}s total) + - Consider: Precompiled headers or extern templates to avoid recompilation + +{% endfor %} +### Most Expensive Individual Instantiations +{% for inst in top_individual[:3] -%} +**{{ loop.index }}. 
{{ inst.detail|truncate(60) }}** - {{ (inst.dur|us_to_ms)|round(1) }}ms + - Strategy: Profile and simplify this specific instantiation + +{% endfor %} + +## Detailed Statistics + +- **Total Unique Templates:** {{ unique_families }} +- **Total Instantiations:** {{ total_instantiations|format_number }} +{% if total_instantiations > 0 -%} +- **Average Instantiation Time:** {{ ((total_template_time // total_instantiations)|us_to_ms)|round(3) }}ms +{% endif -%} +{% if unique_families > 0 -%} +- **Median Template Family Count:** {{ median_count }} +{% endif %} + +--- + +*Report generated using Clang -ftime-trace with {{ granularity }}µs granularity* +*Analysis tool: ck-build-analysis* From de8ee379ad9cc0108949abf0688c2f32c6e23850 Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Fri, 16 Jan 2026 12:17:21 -0600 Subject: [PATCH 15/99] Fixing GEMM Multi D on Tile Engine (#3583) --- tile_engine/ops/gemm/gemm_instance_builder.py | 344 +++++++++--------- 1 file changed, 173 insertions(+), 171 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 9c60c565de..3607bbc59a 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -676,36 +676,38 @@ struct SelectedKernel {{ if self.kernel_name_prefix == "gemm_multi_d": instance_code += """ - // Kernel type - using GemmKernelMultiD = ck_tile::GemmKernelMultiD; - - // Kernel arguments - auto kargs = GemmKernelMultiD::MakeKernelArgs(args); - - if (!GemmKernelMultiD::IsSupportedArgument(kargs)) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); - } + // Kernel type + using GemmKernelMultiD = ck_tile::GemmKernelMultiD; + + // Kernel arguments + auto kargs = GemmKernelMultiD::MakeKernelArgs(args); + + if (!GemmKernelMultiD::IsSupportedArgument(kargs)) { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); + } - // Get grid and block sizes - const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernelMultiD::BlockSize(); - - if(stream.log_level_ > 0) { - std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - }""" + // Get grid and block sizes + const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernelMultiD::BlockSize(); + + if(stream.log_level_ > 0) { + std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + }""" instance_code += f""" - // Launch kernel - constexpr int kBlockPerCu = {k_block_per_cu}; - float ave_time = ck_tile::launch_kernel( - stream, - ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); - - return ave_time; - }};""" + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + float ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }} +}}; +""" elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: instance_code += f""" @@ -713,32 +715,32 @@ struct SelectedKernel {{ // Kernel type using GemmKernel = ck_tile::GemmKernel; - // Kernel arguments - auto kargs = GemmKernel::MakeKernelArgs(args); - - if (!GemmKernel::IsSupportedArgument(kargs)) {{ - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); - }} + // Kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); + }} - // Get grid and block sizes - const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; - const dim3 blocks = GemmKernel::BlockSize(); - - if(stream.log_level_ > 0) {{ - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' - << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; - }}""" + // Get grid and block sizes + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }}""" instance_code += f""" - // Launch kernel - constexpr int kBlockPerCu = {k_block_per_cu}; - float ave_time = ck_tile::launch_kernel( - stream, - ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - - return ave_time; + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + float ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; }} }}; """ @@ -747,8 +749,8 @@ struct SelectedKernel {{ def populate_epilogue(self, epilogue): instance_code = """ - // Epilogue - """ + // Epilogue + """ if epilogue == "cshuffle": if self.kernel_name_prefix == "gemm_universal": @@ -769,145 +771,145 @@ struct SelectedKernel {{ def populate_cshuffle_gemm_universal(self): instance_code = """ - using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - 
ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC, // isCTransposed_ - NumWaveGroups>; // kNumWaveGroups_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code def populate_cshuffle_gemm_multi_d(self): instance_code = """ - using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - DsDataType, - AccDataType, - CDataType, - DsLayout, - CLayout, - ElementWiseFn, - TileM, // kM_ - TileN, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TileM, // kM_ + TileN, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code def populate_cshuffle_gemm_preshuffle(self): instance_code = """ - using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // 
DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC, // isCTransposed_ - NumWaveGroups, // kNumWaveGroups_ - false, // FixedVectorSize_ - 1, // VectorSizeC_ - PermuteN>; // isPermuteN_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + NumWaveGroups, // kNumWaveGroups_ + false, // FixedVectorSize_ + 1, // VectorSizeC_ + PermuteN>; // isPermuteN_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code def populate_default_gemm_universal(self): instance_code = """ - using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // 
isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" return instance_code def populate_default_gemm_multi_d(self): instance_code = """ - using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - DsDataType, - AccDataType, - CDataType, - DsLayout, - CLayout, - ElementWiseFn, - TileM, // kM_ - TileN, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TileM, // kM_ + TileN, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" return instance_code def populate_default_gemm_preshuffle(self): instance_code = """ - using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" return instance_code def 
_generate_cmake_individual_targets(self, kernel_list): From 427d4fb9e947ab73f374c7a941d9f84795662917 Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Fri, 16 Jan 2026 13:34:44 -0500 Subject: [PATCH 16/99] CK Tile: fix some issues (#3557) * Adding CK Tile documentation * Updates based on feedback * Fix tile window API description * Fix remaining images * add documentation about flush_cache and rotating_buffer functionality in ck_tile * Supplement the documentation * light edit of the ck tile conceptual doc --------- Co-authored-by: Vidyasagar Co-authored-by: AviralGoelAMD Co-authored-by: ThomasNing --- docs/conceptual/ck_tile/buffer_views.rst | 233 ++++++++++++----------- 1 file changed, 117 insertions(+), 116 deletions(-) diff --git a/docs/conceptual/ck_tile/buffer_views.rst b/docs/conceptual/ck_tile/buffer_views.rst index 14b8309504..03b8e87b1b 100644 --- a/docs/conceptual/ck_tile/buffer_views.rst +++ b/docs/conceptual/ck_tile/buffer_views.rst @@ -1,35 +1,13 @@ -.. meta:: - :description: Composable Kernel CK Tile buffer views - :keywords: composable kernel, CK, CK Tile, ROCm, API, buffer view, raw memory - .. _ck_tile_buffer_views: -CK Tile buffer view -======================= - -Buffer view is an abstraction that provides structured access to memory. The ``buffer_view`` class is exposed in ``include/ck_tile/core/tensor/buffer_view.hpp``. - -Buffer view serves as the foundation for :ref:`ck_tile_tensor_views`. BufferView handles memory addressing and type safety, while TensorView builds upon this to add multi-dimensional coordinates (shape and strides). 
- - -Buffer view provides the following advantages: - -* A unified interface across global, shared, and register memory -* Address spaces encoded in types, taking advantage of compile-time type checking -* Configurable handling of invalid values, out-of-bounds operations, and conditional access patterns -* Atomic operations for parallel algorithms -* AMD GPU-specific optimizations -* Automatic application of appropriate memory ordering constraints and cache control directives based on the target address space and operation type - - -[TO DO: do we want to say more about these items? There wasn't a lot of detail in the original text, so I put them in a list for now] - - +Buffer Views - Raw Memory Access Address Space Usage Patterns ---------------------------- -[TO DO: explain in words what the diagram shows] +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -66,18 +44,26 @@ Address Space Usage Patterns style Compute fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + + + + .. image:: diagrams/buffer_views_1.svg :alt: Diagram :align: center +C++ Implementation +------------------ +**File**: ``include/ck_tile/core/tensor/buffer_view.hpp`` Basic Creation ~~~~~~~~~~~~~~ -[TO DO: remove "modern C++ template metaprogramming" and "zero-overhead abstraction"] +By encoding critical properties such as buffer size and address space as template parameters, BufferView transforms what would traditionally be runtime decisions into compile-time constants. This design philosophy enables the compiler to perform aggressive optimizations, including constant propagation, loop unrolling, and instruction selection, that would be impossible with runtime parameters. -[TO DO: might want to move the implementation details to a separate section under "reference"] +The use of compile-time constants extends beyond mere optimization. 
When the buffer size is encoded in the type system using constructs like ``number<8>{}``, the compiler can statically verify that array accesses are within bounds, eliminate unnecessary bounds checks, and even restructure algorithms to better match the known data dimensions. This compile-time knowledge propagates through the entire computation, enabling optimizations at every level of the abstraction hierarchy. +The address space template parameter represents another crucial design decision. By making the memory space part of the type system, BufferView ensures that operations appropriate for one memory space cannot be accidentally applied to another. This type safety prevents common errors such as attempting atomic operations on register memory or using global memory synchronization primitives on local memory. The compiler enforces these constraints at compile time, transforming potential runtime errors into compile-time diagnostics. .. code-block:: cpp @@ -98,7 +84,6 @@ Basic Creation buffer_size // number of elements ); - // Implementation detail: The actual C++ template is: // template (data, buffer_size, custom_invalid); - - // Invalid element access with is_valid_element=false - // Returns custom_invalid due to custom invalid value mode - auto invalid_value = buffer_view.template get(0, 0, false); - printf("Invalid element: %.1f\n", invalid_value.get(0)); - - // Out of bounds access - AMD buffer addressing handles bounds checking - // Will return custom_invalid when accessing beyond buffer_size - auto oob_value = buffer_view.template get(0, 100, true); - printf("Out of bounds: %.1f\n", oob_value.get(0)); - - - - - Get Operations -------------- -[TO DO: might want to put this implementation detail in the reference section] +Scalar Access +~~~~~~~~~~~~~ -The signature for the ``buffer_view`` ``get()`` takes four parameters: +The get operations in BufferView form the cornerstone of memory access patterns in CK Tile. 
These operations embody an advanced understanding of GPU memory systems and the patterns that lead to optimal performance. The scalar access interface incorporates multiple layers of optimization and safety mechanisms that work together to provide both performance and correctness. -``i``: the primary offset into the buffer expressed in terms of elements of type T rather than raw bytes. +The parameter structure of scalar access operations reflects careful design choices aimed at maximizing flexibility while maintaining efficiency. The base index parameter ``i`` represents the primary offset into the buffer, expressed in terms of elements of type T rather than raw bytes. This type-aware indexing prevents common errors related to pointer arithmetic and ensures that vector types are handled correctly. The additional ``linear_offset`` parameter provides fine-grained control over the final access location, enabling complex access patterns without requiring expensive index calculations in the kernel code. -``linear_offset``: [TO DO: what is this?] +The ``is_valid_element`` parameter provides a solution to conditional memory access. Rather than using traditional if-statements that would cause warp divergence, this boolean parameter enables predicated execution where the memory access occurs unconditionally but the result is conditionally used. This approach maintains uniform control flow across all threads in a warp, preserving the SIMD execution model that is fundamental to GPU performance. -``is_valid_element``: [TO DO: what is this?] +The invalid value modes provide a mechanism for handling the boundary conditions that arise in parallel algorithms. When ``InvalidElementUseNumericalZeroValue`` is set to true, the system returns zero for any invalid access, whether due to the ``is_valid_element`` flag or out-of-bounds indexing. 
This mode is important for algorithms where zero serves as a natural extension value, such as in image processing with zero-padding or sparse matrix operations where missing elements are implicitly zero. -[TO DO: the last param, that's the out of bounds handling, yes? -.. code:: cpp +The custom invalid value mode, activated when ``InvalidElementUseNumericalZeroValue`` is false, offers additional flexibility for algorithms with specific boundary requirements. This mode returns a user-specified value for invalid accesses, accommodating use cases such as sentinel values in sorting algorithms, infinity values in optimization problems, or special markers in data processing pipelines. The implementation ensures that this flexibility comes without performance penalty, using the same branchless execution strategies as the zero mode. - get(index_t i, - index_t linear_offset, - bool is_valid_element, - bool_constant = {}) +Out-of-bounds handling leverages AMD GPU hardware capabilities to provide safety with minimal impact to performance. When AMD buffer addressing is enabled, the hardware automatically clamps memory accesses to valid ranges, preventing the segmentation faults that would occur on CPU systems. This hardware-assisted bounds checking operates at wire speed, adding no overhead to the memory access path while ensuring that kernels cannot corrupt memory outside their allocated regions. +Vector Access +~~~~~~~~~~~~~ -[TO DO: need some context around the code] +Vector memory operations represent one of the most critical optimizations available in modern GPU programming, and BufferView's vector access interface exposes this capability. By using template parameters to specify vector types through constructs like ``ext_vector_t``, the interface enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. 
This vectorization is crucial for :ref:`ck_tile_load_store_traits`, which automatically selects optimal access patterns. -[TO DO: code chunks need to have detail and explanation so that the reader can see what they're trying to demonstrate.] +The significance of vector operations extends beyond bandwidth improvements. GPUs are designed with wide memory buses that can transfer 128, 256, or even 512 bits per transaction. When scalar operations access only 32 bits at a time, they utilize only a fraction of this available bandwidth. Vector operations align with these wide buses, enabling full bandwidth utilization and reducing the total number of memory transactions required. +The implementation of vector access maintains the same parameter structure as scalar operations, providing consistency across the API while automatically handling the complexities of multi-element transfers. The system manages alignment requirements, ensures that vector loads and stores use the optimal hardware instructions, and handles cases where vector operations extend beyond buffer boundaries. This transparent handling of edge cases allows developers to use vector operations confidently without manual boundary checks or special-case code for partial vectors. -.. 
code-block:: cpp - - // Create buffer view - float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; - auto buffer_view = make_buffer_view(data, 8); - - // Simple get - compile-time bounds checking when possible - auto value_buf = buffer_view.template get(0,1,true); //get the buffer from the buffer view - float value = value_buf.get(0); //get the value from the buffer - - // Get with valid flag - branchless conditional access - bool valid_flag = false; - value_buf = buffer_view.template get(0,1,valid_flag); - value = value_buf.get(0); - // Returns 0 valid_flag is false - - // vectorized get - using float2 = ext_vector_t; - auto vector_buf = buffer_view.template get(0, 0, true); - // Loads 2 floats in a single instruction - float val1 = vector_buf.get(0); - float val2 = vector_buf.get(1); - } - -``ext_vector_t`` enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. - -[TO DO: what is it actually doing? When does one use scalars vs vectors? Is it application specific or are there ] +Scalar vs Vectorized Memory Access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -287,8 +216,9 @@ The signature for the ``buffer_view`` ``get()`` takes four parameters: Understanding BufferView Indexing ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -[TO DO: an explanation of the diagram is needed] - +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -335,14 +265,69 @@ Understanding BufferView Indexing .. image:: diagrams/buffer_views_3.svg :alt: Diagram :align: center - - + +C++ Get Operations +~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: cpp + + __device__ void example_get_operations() + { + // Create buffer view + float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + auto buffer_view = make_buffer_view(data, 8); + + // Simple get - compile-time bounds checking when possible + auto value_buf = buffer_view.template get(0,1,true); //get the buffer from the buffer view + float value = value_buf.get(0); //get the value from the buffer + + // Get with valid flag - branchless conditional access + bool valid_flag = false; + value_buf = buffer_view.template get(0,1,valid_flag); + value = value_buf.get(0); + // Returns 0 valid_flag is false + + // vectorized get + using float2 = ext_vector_t; + auto vector_buf = buffer_view.template get(0, 0, true); + // Loads 2 floats in a single instruction + float val1 = vector_buf.get(0); + float val2 = vector_buf.get(1); + } + +Custom Value Return Mode for OOB & Invalid Access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + void scalar_get_operations_example() { + + // Create data array + constexpr size_t buffer_size = 8; + float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + float custom_invalid = 13.0f; + + // Create global memory buffer view with zero invalid value mode (default) + auto buffer_view = make_buffer_view(data, buffer_size, custom_invalid); + + // Invalid element access with is_valid_element=false + // Returns custom_invalid due to custom invalid value mode + auto invalid_value = buffer_view.template get(0, 0, false); + printf("Invalid element: %.1f\n", invalid_value.get(0)); + + // Out of bounds access - AMD buffer addressing handles bounds checking + // Will return custom_invalid when accessing beyond buffer_size + auto oob_value = buffer_view.template get(0, 100, true); + printf("Out of bounds: %.1f\n", oob_value.get(0)); + } + +.. note:: + + Partial Out Of Bound (OOB) access during vector reads will return 'junk' values for the OOB access. 
Zero or custom invalid value is only returned for complete invalid/OOB access, in other words, it is only returned when the first address of the vector is invalid. Update Operations ----------------- -Update operations modify the buffer content. The ``set()`` method writes a value to a specific location. - .. code-block:: cpp void scalar_set_operations_example() { @@ -373,8 +358,6 @@ Update operations modify the buffer content. The ``set()`` method writes a value Atomic Operations ----------------- -[TO DO: this needs information] - Atomic vs Non-Atomic Operations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -441,3 +424,21 @@ C++ Atomic Operations __syncthreads(); } + +Summary +------- + +BufferView abstracts GPU memory hierarchies behind a concise interface. The approach is intended to keep overhead small while enabling optimizations that are otherwise awkward in low-level code. + +BufferView offers a unified interface across global, shared, and register memory. Using the same API for each space can lower cognitive overhead, reduce certain classes of mistakes, and support code reuse via template parameters. + +Address spaces are encoded in types so that common errors are reported at compile time. Consistent with CK Tile’s zero-overhead design aim, compile-time checks are favored over runtime guards. The C++ type system enforces memory-space constraints and can make valid cases more amenable to compiler optimization. + +BufferView supports configurable handling of invalid values, optional runtime bounds checks, and conditional access patterns. It also provides atomic operations for thread-safe updates. These features are intended to cover common edge cases without adding unnecessary overhead. + +By hiding the complexity of different memory spaces while exposing the operations needed for high-performance GPU computing, BufferView establishes a pattern that the rest of CK Tile follows: compile-time abstractions that enhance rather than compromise performance. 
The :ref:`ck_tile_tensor_views` and :ref:`ck_tile_distribution` add capability while maintaining the efficiency established at the base. For hardware-specific details about memory hierarchies, see :ref:`ck_tile_gpu_basics`. + +Next Steps +---------- + +Continue to :ref:`ck_tile_tensor_views` to learn how to build structured tensor views on top of buffer views. From 2d233c838a46e6797b96a0b5270eb46641782e5a Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 16 Jan 2026 10:36:23 -0800 Subject: [PATCH 17/99] Disable CK Builder for SLES15 in Jenkins CI (#3581) 1. Added `-DCK_EXPERIMENTAL_BUILDER=OFF` to the `setup_args` to explicitly disable the experimental builder 2. Added a detailed comment explaining why this is necessary: - SLES15 is a legacy platform with limited C++20 ecosystem support - While the ROCm compiler supports C++20, the older system libraries and standard library implementation on SLES15 does not reliably support all C++20 features required by the experimental CK Builder --- Jenkinsfile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index e01cfcbf01..e8ce97780d 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1731,7 +1731,10 @@ pipeline { } agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ + // SLES15 is a legacy platform with limited C++20 ecosystem support (older system libraries, + // standard library implementation). While the ROCm compiler supports C++20, the experimental + // CK Builder requires full C++20 feature support that is not reliably available on SLES15. 
+ setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 -DCK_EXPERIMENTAL_BUILDER=OFF """ execute_args = " " } steps{ From fec81109f1a9156d33806d614ee93321f76c4b6a Mon Sep 17 00:00:00 2001 From: logicat <35831253+ca1ic0@users.noreply.github.com> Date: Sat, 17 Jan 2026 02:40:05 +0800 Subject: [PATCH 18/99] Remove unnecessary hip_fp16 include from stream_config (#3549) --- include/ck/stream_config.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 896c048781..ea1c15b1aa 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include struct StreamConfig { From 3f735c127b8e78b702a31e19cb6e0e35eda3588a Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Fri, 16 Jan 2026 19:56:53 +0100 Subject: [PATCH 19/99] [CK Profiler] Restore CPU tensor initialization when verification is not done on GPU (#3594) * Fix large case init bounds * Revert "Fix large case init bounds" This reverts commit 1abca05c6f71ff6fee83fa870d0c84d86279bb70. 
* Restore CPU initialization for do_verification != 2 --- .../profile_grouped_conv_bwd_data_impl.hpp | 93 +++++++++++-------- .../profile_grouped_conv_bwd_weight_impl.hpp | 74 +++++++++------ .../profile_grouped_conv_fwd_impl.hpp | 86 ++++++++++------- 3 files changed, 152 insertions(+), 101 deletions(-) diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index 20bbd58f61..eceb70c05f 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -62,7 +62,13 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, std::cout << "wei: " << wei_g_k_c_xs_desc << std::endl; std::cout << "in: " << in_g_n_c_wis_desc << std::endl; - // Get element space sizes + // Create host tensors + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + + // Get element space sizes for allocation const auto out_element_space_size = out_g_n_k_wos_desc.GetElementSpaceSize(); const auto wei_element_space_size = wei_g_k_c_xs_desc.GetElementSpaceSize(); const auto in_element_space_size = in_g_n_c_wis_desc.GetElementSpaceSize(); @@ -72,48 +78,57 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_element_space_size); DeviceMem in_device_buf(sizeof(InDataType) * in_element_space_size); - // Generate data directly on GPU using DeviceMem methods - switch(init_method) + // Initialize tensors based on do_verification: + // - do_verification=2: GPU-side initialization + // - do_verification=0,1: CPU-side initialization + if(do_verification == 2) { - case 0: - // Zero initialization - out_device_buf.SetZero(); - wei_device_buf.SetZero(); - break; - case 1: - // Discrete integer values in range [-5, 5] - out_device_buf.FillUniformRandInteger(-5, 5); - 
wei_device_buf.FillUniformRandInteger(-5, 5); - break; - case 2: - // Continuous float values - out_device_buf.FillUniformRandFp(0.0f, 1.0f); - wei_device_buf.FillUniformRandFp(-0.5f, 0.5f); - break; - default: - // Constant value 1 - out_device_buf.SetValue(ck::type_convert(1)); - wei_device_buf.SetValue(ck::type_convert(1)); + // GPU-side initialization for GPU verification workflow + switch(init_method) + { + case 0: + // Zero initialization + out_device_buf.SetZero(); + wei_device_buf.SetZero(); + break; + case 1: + // Discrete integer values in range [-5, 5] + out_device_buf.FillUniformRandInteger(-5, 5); + wei_device_buf.FillUniformRandInteger(-5, 5); + break; + case 2: + // Continuous float values + out_device_buf.FillUniformRandFp(0.0f, 1.0f); + wei_device_buf.FillUniformRandFp(-0.5f, 0.5f); + break; + default: + // Constant value 1 + out_device_buf.SetValue(ck::type_convert(1)); + wei_device_buf.SetValue(ck::type_convert(1)); + } } - - // Create host tensors (needed only for verification) - Tensor out(out_g_n_k_wos_desc); - Tensor wei(wei_g_k_c_xs_desc); - Tensor in_host(in_g_n_c_wis_desc); - Tensor in_device(in_g_n_c_wis_desc); - - // Copy GPU→CPU only if verification is enabled - if(do_verification == 1 || do_verification == 2) + else { - out_device_buf.FromDevice(out.mData.data()); - wei_device_buf.FromDevice(wei.mData.data()); - } + // CPU-side initialization for do_verification=0,1 + switch(init_method) + { + case 0: break; // Tensors are already zero-initialized by default + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_1{1}); + wei.GenerateTensorValue(GeneratorTensor_1{1}); + } - // Copy to host only if CPU verification is needed - if(do_verification == 1) - { - 
out_device_buf.FromDevice(out.mData.data()); - wei_device_buf.FromDevice(wei.mData.data()); + // Copy initialized host data to device + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); } // Allocate GPU reference buffer (used only if do_verification == 2) diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index f1498f4c2d..3a9f14e595 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -67,7 +67,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, std::cout << "weight: " << wei_g_k_c_xs_desc << std::endl; std::cout << "output: " << out_g_n_k_wos_desc << std::endl; - // Get element space sizes + // Create host tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight_host_result(wei_g_k_c_xs_desc); + Tensor weight_device_result(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + + // Get element space sizes for allocation const auto input_element_space_size = in_g_n_c_wis_desc.GetElementSpaceSize(); const auto weight_element_space_size = wei_g_k_c_xs_desc.GetElementSpaceSize(); const auto output_element_space_size = out_g_n_k_wos_desc.GetElementSpaceSize(); @@ -77,36 +83,48 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, DeviceMem wei_device_buf(sizeof(WeiDataType) * weight_element_space_size); DeviceMem out_device_buf(sizeof(OutDataType) * output_element_space_size); - // Generate data directly on GPU using DeviceMem methods - switch(init_method) + // Initialize tensors based on do_verification: + // - do_verification=2: GPU-side initialization + // - do_verification=0,1: CPU-side initialization + if(do_verification == 2) { - case 0: - // Zero initialization - in_device_buf.SetZero(); - out_device_buf.SetZero(); - break; - case 1: - // Discrete integer values in range [-5, 5] - 
in_device_buf.FillUniformRandInteger(-5, 5); - out_device_buf.FillUniformRandInteger(-5, 5); - break; - default: - // Continuous float values - in_device_buf.FillUniformRandFp(0.0f, 1.0f); - out_device_buf.FillUniformRandFp(-0.5f, 0.5f); + // GPU-side initialization for GPU verification workflow + switch(init_method) + { + case 0: + // Zero initialization + in_device_buf.SetZero(); + out_device_buf.SetZero(); + break; + case 1: + // Discrete integer values in range [-5, 5] + in_device_buf.FillUniformRandInteger(-5, 5); + out_device_buf.FillUniformRandInteger(-5, 5); + break; + default: + // Continuous float values + in_device_buf.FillUniformRandFp(0.0f, 1.0f); + out_device_buf.FillUniformRandFp(-0.5f, 0.5f); + } } - - // Create host tensors (needed only for verification) - Tensor input(in_g_n_c_wis_desc); - Tensor weight_host_result(wei_g_k_c_xs_desc); - Tensor weight_device_result(wei_g_k_c_xs_desc); - Tensor output(out_g_n_k_wos_desc); - - // Copy to host only if CPU verification is needed - if(do_verification == 1) + else { - in_device_buf.FromDevice(input.mData.data()); - out_device_buf.FromDevice(output.mData.data()); + // CPU-side initialization for do_verification=0,1 + switch(init_method) + { + case 0: break; // Tensors are already zero-initialized by default + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + output.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + output.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + // Copy initialized host data to device + in_device_buf.ToDevice(input.mData.data()); + out_device_buf.ToDevice(output.mData.data()); } // Allocate GPU reference buffer (used only if do_verification == 2) diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index 54bb66c42e..bbafdee417 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp 
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -86,50 +86,68 @@ bool profile_grouped_conv_fwd_impl(int do_verification, copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); - // Get element space sizes for GPU allocation - const auto input_size = in_g_n_c_wis_desc.GetElementSpaceSize(); - const auto weight_size = wei_g_k_c_xs_desc.GetElementSpaceSize(); - const auto output_size = out_g_n_k_wos_desc.GetElementSpaceSize(); - std::cout << "input: " << in_g_n_c_wis_desc << std::endl; std::cout << "weight: " << wei_g_k_c_xs_desc << std::endl; std::cout << "output: " << out_g_n_k_wos_desc << std::endl; - // Allocate GPU memory first (GPU-first workflow) - DeviceMem in_device_buf(sizeof(InDataType) * input_size); - DeviceMem wei_device_buf(sizeof(WeiDataType) * weight_size); - DeviceMem out_device_buf(sizeof(OutDataType) * output_size); - - // Generate data directly on GPU using DeviceMem methods - switch(init_method) - { - case 0: - // Zero initialization - in_device_buf.SetZero(); - wei_device_buf.SetZero(); - break; - case 1: - // Discrete integer generation: {-5, -4, -3, ..., 3, 4} - in_device_buf.FillUniformRandInteger(-5, 5); - wei_device_buf.FillUniformRandInteger(-5, 5); - break; - default: - // Continuous float generation - in_device_buf.FillUniformRandFp(0.0f, 1.0f); - wei_device_buf.FillUniformRandFp(-0.5f, 0.5f); - } - - // Create host tensors (for verification if needed) + // Create host tensors Tensor input(in_g_n_c_wis_desc); Tensor weight(wei_g_k_c_xs_desc); Tensor host_output(out_g_n_k_wos_desc); Tensor device_output(out_g_n_k_wos_desc); - // Copy to host only if CPU verification is needed - if(do_verification == 1) + // Get element space sizes for allocation + const auto input_size = in_g_n_c_wis_desc.GetElementSpaceSize(); + const auto weight_size = wei_g_k_c_xs_desc.GetElementSpaceSize(); + const auto output_size = out_g_n_k_wos_desc.GetElementSpaceSize(); + + // Allocate GPU 
memory + DeviceMem in_device_buf(sizeof(InDataType) * input_size); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight_size); + DeviceMem out_device_buf(sizeof(OutDataType) * output_size); + + // Initialize tensors based on do_verification: + // - do_verification=2: GPU-side initialization + // - do_verification=0,1: CPU-side initialization + if(do_verification == 2) { - in_device_buf.FromDevice(input.mData.data()); - wei_device_buf.FromDevice(weight.mData.data()); + // GPU-side initialization for GPU verification workflow + switch(init_method) + { + case 0: + // Zero initialization + in_device_buf.SetZero(); + wei_device_buf.SetZero(); + break; + case 1: + // Discrete integer generation: {-5, -4, -3, ..., 3, 4} + in_device_buf.FillUniformRandInteger(-5, 5); + wei_device_buf.FillUniformRandInteger(-5, 5); + break; + default: + // Continuous float generation + in_device_buf.FillUniformRandFp(0.0f, 1.0f); + wei_device_buf.FillUniformRandFp(-0.5f, 0.5f); + } + } + else + { + // CPU-side initialization for do_verification=0,1 + switch(init_method) + { + case 0: break; // Tensors are already zero-initialized by default + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + // Copy initialized host data to device + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); } // Allocate GPU reference buffer (used only if do_verification == 2) From f9104ef9b3b794f8e02757cbf2935818f5389dac Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Fri, 16 Jan 2026 16:27:39 -0700 Subject: [PATCH 20/99] [CK TILE QUANT GEMM] use OverrideADataType in aquant pipeline (#3584) --- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 23 ++++++---- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 44 +++++++++---------- 2 
files changed, 35 insertions(+), 32 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 2f6497fdba..650cd947f7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -28,7 +28,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + // When ADataType is pk_int4_t, use BDataType instead for transpose operations + // since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision) + using OverrideADataType = + std::conditional_t, BDataType, ADataType>; static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!"); static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -228,9 +232,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "B block window has incorrect lengths for defined BLayout!"); // A/B tiles in LDS - using the same approach as regular gemm pipeline - auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); - auto& a_lds_block = ab_lds_blocks.at(I0{}); - auto& b_lds_block = ab_lds_blocks.at(I1{}); + auto ab_lds_blocks = + Base::template GetABLdsTensorViews(p_smem); + auto& a_lds_block = ab_lds_blocks.at(I0{}); + auto& b_lds_block = ab_lds_blocks.at(I1{}); // Tile distribution for load from lds constexpr auto a_lds_load_tile_distr = @@ -260,7 +265,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); + 
decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -295,7 +300,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // LDS prefill - VGPRs to LDS if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -346,7 +351,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // Prepare next iteration data if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d( a_shuffle_tmp, @@ -406,7 +411,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); @@ -494,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const BDataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 22dd78e070..71e4a74400 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ 
b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -25,7 +25,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + // When ADataType is pk_int4_t, use BDataType instead for transpose operations + // since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision) + using OverrideADataType = + std::conditional_t, BDataType, ADataType>; static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!"); static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -164,14 +168,17 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 - CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile, - const ADramWindow& a_dram_window) + template + CK_TILE_DEVICE static void + LoadAndConvertATile(ABlockTile_& a_block_tile, + ADramWindow& a_dram_window, + const DramTileWindowStep& dram_tile_window_step) { using DestDataType = typename ABlockTile_::DataType; using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; load_int4_tile(a_block_tile, a_dram_window); + move_tile_window(a_dram_window, dram_tile_window_step); } template (p_smem); + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -241,11 +248,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -274,8 +278,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( 
Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -306,8 +309,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -349,8 +351,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -430,10 +431,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - // Note: a_element_func takes BDataType (not ADataType) because A tiles are - // converted from ADataType (e.g., pk_int4_t) to BDataType (e.g., fp8) in - // LoadAndConvertATile before the element function is applied. 
- [](const BDataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, @@ -476,7 +474,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, From fe40a5d13941b64162cffce9496d1d94a90f80a5 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Sat, 17 Jan 2026 08:30:27 +0100 Subject: [PATCH 21/99] Implement batched gemm bias permute for RDNA4 (#3534) * feat: test setup for batched contraction (aka batched gemm multiple d e permute) * wip: device struct for WMMA batched contraction multiple d based on new gridwise op * feat: working batched contraction on RDNA, non-naive tensor descriptors for gridwise_gemm_wmma_cshuffle_v3, test setup for odd cases * fix: failure to resolve template parameters when calling new function overload * fix: passing reference type as parameter instead of underlying types * fix: merge error caused duplicate definitions * fix: make sure constness of template and parameters types match * fix: don't compile batched contraction test on unsupported architectures * feat: add example for new wmma implementation, and consolidate example code between platforms * style: return inline instead of with branch * chore: add extra assert on vector memory access sizes * chore: clean up some unused variables * fix: correct tail number calculation, added small cases and extra instances to the test * fix: properly support wave transfer by generating correct grid descriptors dependent on the transfer method --- .../gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp | 168 +-- .../gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp | 169 +--- .../CMakeLists.txt | 1 + .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 351 +------ 
...tched_gemm_bias_e_permute_wmma_v3_fp16.cpp | 111 ++ .../batched_gemm_bias_e_permute_xdl_fp16.cpp | 339 +------ ...un_batched_gemm_bias_e_permute_example.inc | 350 +++++++ ...ontraction_multiple_d_wmma_cshuffle_v3.hpp | 956 ++++++++++++++++++ .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 101 +- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 76 +- .../cpu/reference_contraction.hpp | 273 +++++ .../gpu/batched_gemm_bias_permute.hpp | 25 + .../batched_gemm_bias_permute/CMakeLists.txt | 3 +- ...mma_c_shuffle_f16_f16_f16_f16_instance.cpp | 78 ++ ...le_batched_contraction_multiple_d_impl.hpp | 309 ++++++ test/CMakeLists.txt | 1 + test/batched_contraction/CMakeLists.txt | 9 + .../test_batched_contraction.cpp | 164 +++ 18 files changed, 2475 insertions(+), 1009 deletions(-) create mode 100644 example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_v3_fp16.cpp create mode 100644 example/29_batched_gemm_bias_e_permute/run_batched_gemm_bias_e_permute_example.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_instance.cpp create mode 100644 profiler/include/profiler/profile_batched_contraction_multiple_d_impl.hpp create mode 100644 test/batched_contraction/CMakeLists.txt create mode 100644 test/batched_contraction/test_batched_contraction.cpp diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index f7663cbd0a..6295cfdd04 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -17,7 +17,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" -#include 
"ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -69,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = - false> -struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G1_M2_N3_K1::Argument; - - float Run(const Argument& arg) - { - auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, k0))); - arg.b_element_op_( - v_b, - ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c; - }; - - make_ParallelTensorFunctor(f_gs_ms_ns, - 
arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M3_N2_K1" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -353,16 +217,18 @@ int main(int argc, char* argv[]) Tensor c_gs_ms_ns_host_result( e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1; + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceBatchedContraction_G1_M2_N3_K1; auto ref_gemm = ReferenceOpInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -399,7 +265,13 @@ int main(int argc, char* argv[]) } } - return ck::utils::check_err(e_gs_ms_ns_device_result, 
e_gs_ms_ns_host_result) ? 0 : 1; + bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result); + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + + if(!pass) + { + return 1; + } } return 0; diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp index 736dc09867..3adfecc7ae 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp @@ -17,6 +17,8 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; using ::ck::make_ParallelTensorFunctor; @@ -67,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -template = - false> -struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G1_M3_N2_K1::Argument; - - float 
Run(const Argument& arg) - { - auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, - ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_gs_ms_ns, - arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_G1_M3_N2_K1" - << std::endl; - 
// clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -353,17 +219,18 @@ int main(int argc, char* argv[]) Tensor c_gs_ms_ns_host_result( e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1; + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceBatchedContraction_G1_M3_N2_K1; auto ref_gemm = ReferenceOpInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -400,7 +267,13 @@ int main(int argc, char* argv[]) } } - return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result); + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + + if(!pass) + { + return 1; + } } return 0; diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index d5d5521370..6cf93215f8 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -3,3 +3,4 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) +add_example_executable(example_batched_gemm_bias_e_permute_wmma_v3_fp16 batched_gemm_bias_e_permute_wmma_v3_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 06bf971ac4..f102a0b132 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -106,352 +106,5 @@ using DeviceOpInstanceKKNN = using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded 
for NumDimM == NumDimN == NumDimK == 2 -template = - false> -struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, - ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); - arg.b_element_op_( - v_b, - ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* 
p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_G2_M2_N2_K1" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - ck::index_t G0 = 1; - ck::index_t G1 = 2; - - ck::index_t M0 = 4; - ck::index_t M1 = 128; - - ck::index_t N0 = 16; - ck::index_t N1 = 256; - - ck::index_t K0 = 2048; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 11) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - G0 = std::stoi(argv[4]); - G1 = std::stoi(argv[5]); - M0 = std::stoi(argv[6]); - M1 = std::stoi(argv[7]); - N0 = std::stoi(argv[8]); - N1 = std::stoi(argv[9]); - K0 = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal 
value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n"); - exit(0); - } - - // A[G0, G1, M0, M1, K0] - std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; - std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; - // B[G0, G1, N0, N1, K0] - std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; - std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; - - // D[G0, G1, M0, N0, M1, N1] - std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; - // E[G0, G1, M0, N0, M1, N1] - std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector e_gs_ms_ns_strides{ - G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; - std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; - std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; - std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - DeviceMem a_device_buf(sizeof(ADataType) * 
a_gs_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * - e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); - b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); - d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_gs_ms_ks_lengths, - a_gs_ms_ks_strides, - b_gs_ns_ks_lengths, - b_gs_ns_ks_strides, - std::array, 1>{d_gs_ms_ns_lengths}, - std::array, 1>{d_gs_ms_ns_strides}, - e_gs_ms_ns_lengths, - e_gs_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t G = - ck::accumulate_n(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); - - ck::index_t M = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); - std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl; - std::size_t flop = std::size_t(2) * G * M * N * K; - std::size_t num_btype = 
sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + - sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result( - e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - - using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) - { - for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) - { - for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) - { - for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) - { - for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) - { - for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; - ++n1) - { - cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); - } - } - } - } - } - } - - return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 
0 : 1; - } - - return 0; -} +#include "run_batched_gemm_bias_e_permute_example.inc" +int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); } diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_v3_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..4e34f18b8b --- /dev/null +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_v3_fp16.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::make_ParallelTensorFunctor; +using ::ck::Tensor; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + 
+using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceOpInstanceKKNN = + ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + ASpec, + BSpec, + DESpec, + 128, + 64, + 64, + 64, + 4, + 4, + 16, + 16, + 1, + 4, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + false, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + false, + 1, + 1, + S<1, 64, 1, 2>, + S<8, 8>>; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_batched_gemm_bias_e_permute_example.inc" +int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); } diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp index d7f468bc62..4ed054faaa 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp @@ -67,340 +67,5 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = - false> -struct ReferenceContraction_G2_M2_N2_K1 : public 
ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, - ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); - arg.b_element_op_( - v_b, - ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return 
Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_G2_M2_N2_K1" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - ck::index_t G0 = 1; - ck::index_t G1 = 2; - - ck::index_t M0 = 4; - ck::index_t M1 = 256; - - ck::index_t N0 = 16; - ck::index_t N1 = 128; - - ck::index_t K0 = 64; - - // A[G0, G1, M0, M1, K0] - std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; - std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; - // B[G0, G1, N0, N1, K0] - std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; - std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; - - // D[G0, G1, M0, N0, M1, N1] - std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; - // E[G0, G1, M0, N0, M1, N1] - std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector e_gs_ms_ns_strides{ - G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - - if(argc == 1) - { - // use default case - } - 
else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - - std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; - std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; - std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; - std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * - e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); - b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); - d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); - - // set zero - 
e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_gs_ms_ks_lengths, - a_gs_ms_ks_strides, - b_gs_ns_ks_lengths, - b_gs_ns_ks_strides, - std::array, 1>{d_gs_ms_ns_lengths}, - std::array, 1>{d_gs_ms_ns_strides}, - e_gs_ms_ns_lengths, - e_gs_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t G = - ck::accumulate_n(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); - - ck::index_t M = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * G * M * N * K; - std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + - sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result( - e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - - using 
ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) - { - for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) - { - for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) - { - for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) - { - for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) - { - for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; - ++n1) - { - cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); - } - } - } - } - } - } - - return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 
0 : 1; - } - - return 0; -} +#include "run_batched_gemm_bias_e_permute_example.inc" +int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); } diff --git a/example/29_batched_gemm_bias_e_permute/run_batched_gemm_bias_e_permute_example.inc b/example/29_batched_gemm_bias_e_permute/run_batched_gemm_bias_e_permute_example.inc new file mode 100644 index 0000000000..803c1eb0bf --- /dev/null +++ b/example/29_batched_gemm_bias_e_permute/run_batched_gemm_bias_e_permute_example.inc @@ -0,0 +1,350 @@ + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); + 
+ v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G2_M2_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int run_batched_gemm_bias_e_permute_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t G0 = 1; + ck::index_t G1 = 2; + + ck::index_t M0 = 4; + ck::index_t M1 = 128; + + ck::index_t N0 = 16; + ck::index_t N1 = 256; + + ck::index_t K0 = 2048; + + if(argc == 1) + { + // use default case + } + else if(argc == 
4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + G0 = std::stoi(argv[4]); + G1 = std::stoi(argv[5]); + M0 = std::stoi(argv[6]); + M1 = std::stoi(argv[7]); + N0 = std::stoi(argv[8]); + N1 = std::stoi(argv[9]); + K0 = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n"); + exit(0); + } + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " 
<< e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t G = + ck::accumulate_n(e_gs_ms_ns_lengths.begin(), 
NumDimG, 1, std::multiplies<>{}); + + ck::index_t M = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); + std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl; + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + + using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n1) + { + cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + 
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); + } + } + } + } + } + } + + return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result); + } + + return 1; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..47ef2e339d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,956 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/utility/scheduler_enum.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_contraction_multiple_d_wmma_cshuffle_v3(typename DeviceOp::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + static constexpr index_t NumDTensor = GridwiseOp::NumDTensor; + + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + 
amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = + amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetDsPtrOffset(g_idx)); + + typename GridwiseOp::AsGridPointer p_as_grid_batch{karg.p_a_grid_ + a_batch_offset}; + typename GridwiseOp::BsGridPointer p_bs_grid_batch{karg.p_b_grid_ + b_batch_offset}; + typename GridwiseOp::DsGridPointer p_ds_grid_batch; + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_batch(i) = karg.p_ds_grid_[i] + ds_batch_offset[i]; }); + + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseOp::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + const auto a_grid_desc_ak0_m_ak1 = + GridwiseOp::MakeAGridDescriptor_AK0_M_AK1(karg.a_grid_desc_m_k_); + const auto b_grid_desc_bk0_n_bk1 = + GridwiseOp::MakeBGridDescriptor_BK0_N_BK1(karg.b_grid_desc_n_k_); + + auto epilogue_args = EpilogueType{}; + GridwiseOp::template Run( + p_as_grid_batch, + p_bs_grid_batch, + p_ds_grid_batch, + karg.p_e_grid_ + e_batch_offset, + p_shared, + make_tuple(a_grid_desc_ak0_m_ak1), + make_tuple(b_grid_desc_bk0_n_bk1), + karg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + karg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + karg.block_2_etile_map_, + karg.a_element_op_, + karg.b_element_op_, + karg.cde_element_op_, + epilogue_args); +#else + ignore = karg; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] 
+// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 +// +// Detail- Packed tensor satisfies +// stride_0 = 1 +// stride_i = stride_{i - 1} * extent_{i - 1} +// So tensor +// [G0, G1, G2, M, N] +// transposed into tensor +// [G0, G2, G1, M, N] +// with strides +// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1] +// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some +// strides from input tensor extents so finer dimension information is lost. Merging dimensions is +// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. +// +// Might need to expose dimension order to the interface to fully support +// TensorSpecialization::Packed in a traditional sense of "packed" tensor +template +struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3 + : public DeviceBatchedContractionMultipleD +{ + using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] 
+ static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
+ const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] 
+ const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... 
+ const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... 
+ constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] 
+ const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + // GridwiseGemm + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using DsLayout = decltype(generate_tuple( + [](auto) { return ck::tensor_layout::gemm::RowMajor{}; }, Number{})); + using ELayout = ck::tensor_layout::gemm::RowMajor; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + 
ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA + false // PermuteB + >; + + // block-to-e-tile map + using Block2ETileMap = GridwiseGemm::Block2CTileMap; + + // problem grid descriptors + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}, 0, 0))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{}, 0, 0))>; + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_A_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return 
static_cast(g_idx) * batch_stride_B_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + KBatch(1), + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + 
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + compute_ptr_offset_of_batch_{a_gs_ms_ks_strides[NumDimG - 1], + b_gs_ns_ks_strides[NumDimG - 1], + ds_grid_desc_g_m_n_, + e_grid_desc_g_m_n_} + { + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, + "Invalid number of dimensions"); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i], + ds_gs_ms_ns_strides[i]); + }); + + // Extract 2D GEMM dimensions + G = e_grid_desc_g_m_n_.GetLength(I0); + M = e_grid_desc_g_m_n_.GetLength(I1); + N = e_grid_desc_g_m_n_.GetLength(I2); + K = a_grid_desc_m_k_.GetLength(I1); + AK0 = GridwiseGemm::CalculateAK0Padded(K); + + index_t MBlock = GridwiseGemm::CalculateMBlock(M); + index_t NBlock = GridwiseGemm::CalculateMBlock(N); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_, MBlock, NBlock); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_, MBlock, NBlock); + + block_2_etile_map_ = GridwiseGemm::DefaultBlock2CTileMap(M, N); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; 
+ EDataType* p_e_grid_; + + index_t G, M, N, K; + index_t KBatch; // Always 1, but included for compatability with GridwiseGemm::CheckValidity + index_t AK0; // Also included for compatibility + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + // tensor descriptors for block/thread-wise copy + // AK0_M_AK1/BK0_N_BK1 are generated in the kernel to match the transfer method used + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error( + "wrong! DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3 has invalid " + "setting"); + } + + const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.M, arg.N); + + auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr auto tail_num = tail_number.value; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + const auto kernel = + kernel_contraction_multiple_d_wmma_cshuffle_v3; + + return launch_and_time_kernel( + stream_config, kernel, dim3(grid_size, arg.G, 1), dim3(BlockSize), 0, arg); + }; + + bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + throw std::runtime_error( + "Invalid HasMainKBlockLoop and TailNum combination for pipeline V1!\n"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + throw std::runtime_error( + "Invalid HasMainKBlockLoop and TailNum combination for pipeline V3!\n"); + } + } + else + { + throw std::runtime_error("Invalid pipeline version! 
Only V1 and V3 supported\n"); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "GPU Arch not supported" << std::endl; + } + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "Wrong dimension for A or B vector loads, should be 1 or 2!"); + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const 
std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3" + << "<" + << NumDimG << ", " + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index b46afda8b7..a1cba118b2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -414,22 +414,22 @@ struct GridwiseGemm_wmma_cshuffle_v3 struct Argument : public tensor_operation::device::BaseArgument, public Problem { __host__ Argument() = default; - __host__ 
Argument(std::array p_as_grid_, - std::array p_bs_grid_, - std::array p_ds_grid_, - EDataType* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - std::array StrideAs_, - std::array StrideBs_, - std::array StrideDs_, - index_t StrideE_, - index_t k_batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CDEElementwiseOperation cde_element_op_, - bool is_reduce_ = false) + __host__ __device__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideE_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_, + bool is_reduce_ = false) : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_}, p_as_grid{}, p_bs_grid{}, @@ -607,6 +607,67 @@ struct GridwiseGemm_wmma_cshuffle_v3 MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args, + A_k_id, + B_k_id); + } + + // Overload to pass in custom As/Bs/Ds/E grid descriptors + // Used for contraction operations, where tensor transforms are non-trivial + template + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const AsGridDescriptor_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDescriptor_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const 
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + EpilogueArgument& epilogue_args, + const index_t A_k_id = 0, + const index_t B_k_id = 0) + { + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -773,9 +834,13 @@ struct GridwiseGemm_wmma_cshuffle_v3 B_k_id); } - __device__ static auto DefaultBlock2CTileMap(const Problem& problem) + __device__ __host__ static auto DefaultBlock2CTileMap(const Problem& problem) { - return Block2CTileMap{problem.M, problem.N, 4}; + return DefaultBlock2CTileMap(problem.M, problem.N); + } + __device__ __host__ static auto DefaultBlock2CTileMap(const index_t M, const index_t N) + { + return Block2CTileMap{M, N, 4}; } // Run method for convolution for bwd_data (grid descriptors are passed as arguments, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index ec7710d066..b7b88d4920 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -499,8 +499,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + template __host__ __device__ static auto - MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs, + const index_t M, const index_t MPad, const index_t K, const index_t KPad, @@ -518,10 +520,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base GemmSpec == GemmSpecialization::NKPadding; return generate_tuple( [&](auto i) { - const auto base_desc = MakeAGridDescriptor_M_K(M, K, StrideAs[i]); - return ATransfer::template MakeGridDescriptor( - base_desc, M, MPad, K, KPad, StrideAs[i], 
AK0); + base_descs[i], M, MPad, K, KPad, StrideAs[i], AK0); }, Number{}); } @@ -539,8 +539,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return ATransfer::template MakeGridDescriptor(base_desc, M, M, K, K, 0, AK0); } + template __host__ __device__ static auto - MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs, const index_t KBatch = 1) + { + const index_t M = base_descs.At(I0).GetLength(I0); + const index_t K = base_descs.At(I0).GetLength(I1); + + const index_t MPad = CalculateMPadded(M); + const index_t KPad = CalculateKPadded(K, KBatch); + + const index_t AK0 = CalculateAK0Padded(K, KBatch); + + return MakeAsGridDescriptor_AK0_M_AK1(base_descs, M, MPad, K, KPad, {}, AK0); + } + + __host__ __device__ static auto + MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + const index_t MPad, + const index_t K, + const index_t KPad, + const std::array& StrideAs, + const index_t AK0) + { + const auto base_descs = + generate_tuple([&](auto i) { return MakeAGridDescriptor_M_K(M, K, StrideAs[i]); }, + Number{}); + return MakeAsGridDescriptor_AK0_M_AK1(base_descs, M, MPad, K, KPad, StrideAs, AK0); + } + + template + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& base_descs, + const index_t K, const index_t KPad, const index_t N, const index_t NPad, @@ -558,9 +589,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base GemmSpec == GemmSpecialization::MKPadding; return generate_tuple( [&](auto i) { - const auto base_desc = MakeBGridDescriptor_N_K(N, K, StrideBs[i]); return BTransfer::template MakeGridDescriptor( - base_desc, N, NPad, K, KPad, StrideBs[i], BK0); + base_descs[i], N, NPad, K, KPad, StrideBs[i], BK0); }, Number{}); } @@ -578,6 +608,36 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return BTransfer::template MakeGridDescriptor(base_desc, N, N, K, K, 0, BK0); } + template + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& 
base_descs, const index_t KBatch = 1) + { + const index_t N = base_descs.At(I0).GetLength(I0); + const index_t K = base_descs.At(I0).GetLength(I1); + + const index_t NPad = CalculateNPadded(N); + const index_t KPad = CalculateKPadded(K, KBatch); + + const index_t BK0 = CalculateBK0Padded(K, KBatch); + + return MakeBsGridDescriptor_BK0_N_BK1(base_descs, K, KPad, N, NPad, {}, BK0); + } + + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + const index_t KPad, + const index_t N, + const index_t NPad, + const std::array& StrideBs, + const index_t BK0) + { + + const auto base_descs = + generate_tuple([&](auto i) { return MakeBGridDescriptor_N_K(N, K, StrideBs[i]); }, + Number{}); + return MakeBsGridDescriptor_BK0_N_BK1(base_descs, K, KPad, N, NPad, StrideBs, BK0); + } + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor() { constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); @@ -681,7 +741,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } template - __device__ __host__ static constexpr auto + __host__ __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp index 24343666cc..d73ceb1de5 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp @@ -231,6 +231,279 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base } }; +// hardcoded for NumDimG == 1, NumDimM == 2, NumDimN == 3, NumDimK == 1 +template = + false> +struct ReferenceBatchedContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public 
ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceBatchedContraction_G1_M2_N3_K1::Argument; + + float Run(const Argument& arg) + { + auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c; + }; + + make_ParallelTensorFunctor(f_gs_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // 
TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedContraction_G1_M3_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +template = + false> +struct ReferenceBatchedContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceBatchedContraction_G1_M3_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) { + const int K0 = 
arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_gs_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedContraction_G1_M3_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + } // namespace host } // namespace tensor_operation } // namespace 
ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp index e510f17fb2..9886ccdfbf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp @@ -19,6 +19,7 @@ namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_USE_XDL void add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance( std::vector>>& instances); +#endif + +#ifdef CK_USE_WMMA +void add_device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance( + std::vector>>& instances); +#endif // Contraction + add template && is_same_v && is_same_v && is_same_v) { + if constexpr(NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1) { +#ifdef CK_USE_XDL add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance( + op_ptrs); +#endif } } diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index a4f66fdd4d..a0f9b6fb07 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,8 +1,9 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp + device_batched_gemm_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_instance.cpp new file mode 100644 index 0000000000..8bcd223e19 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_instance.cpp @@ -0,0 +1,78 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +// A[g0, m0, m1, k0] * B[g0, n0, n1, n2, k0] + D[g0, m0, m1, n0, n1, n2] = E[g0, n0, m0, n0, n1, m1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance = + std::tuple< + // clang-format off + //################################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| A| B| DE| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| 
CDEBlockTransfer| + //################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Specialization| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>, + DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>, + DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1>>, + 
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1>>, + DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<4, 4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 64, 64, 32, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<4, 4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/profiler/include/profiler/profile_batched_contraction_multiple_d_impl.hpp b/profiler/include/profiler/profile_batched_contraction_multiple_d_impl.hpp new file mode 100644 index 0000000000..e1035b37ed --- /dev/null +++ b/profiler/include/profiler/profile_batched_contraction_multiple_d_impl.hpp @@ -0,0 +1,309 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/numeric.hpp" + +namespace ck { +namespace profiler { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +template +bool profile_batched_contraction_multiple_d_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::array Gs, + std::array Ms, + std::array Ns, + std::array Ks, + int instance_index = -1, + bool fail_if_no_supported_instances = false) +{ + static_assert(NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, + "Tensor ranks not supported. 
Supported: G=1, M=2, N=3, K=1"); + static_assert(DsDataType::Size() == 1, "Only single D tensor is supported at the moment."); + + using AccDataType = float; + using DDataType = ck::tuple_element_t<0, DsDataType>; + + bool pass = true; + + ignore = do_log; + + ck::index_t G0 = Gs[0]; + + ck::index_t M0 = Ms[0]; + ck::index_t M1 = Ms[1]; + + ck::index_t N0 = Ns[0]; + ck::index_t N1 = Ns[1]; + ck::index_t N2 = Ns[2]; + + ck::index_t K0 = Ks[0]; + + // A[M0, M1, M2, K0] + std::vector a_gs_ms_ks_lengths{G0, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{M0 * M1 * K0, M1 * K0, K0, 1}; + // B[N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, N0, N1, N2, K0}; + std::vector b_gs_ns_ks_strides{N0 * N1 * N2 * K0, N1 * N2 * K0, N2 * K0, K0, 1}; + + // D[N0, M0, N1, M1, N2] + std::vector d_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2}; + std::vector d_gs_ms_ns_strides{N0 * N1 * N2, 0, 0, N1 * N2, N2, 1}; + // E[N0, M0, N1, M1, N2] + std::vector e_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2}; + std::vector e_gs_ms_ns_strides{ + M0 * M1 * N0 * N1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1}; + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + 
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + if(do_verification) + { + Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceBatchedContraction_G1_M2_N3_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks, + b_gs_ns_ks, + c_gs_ms_ns_host_result, + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1) + { + for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n2) + { + cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), + c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), + d_gs_ms_ns(g0, m0, m1, n0, n1, n2)); + } + } + } + } + } + } + } + + // 
get device op instances + using DeviceOp = ck::tensor_operation::device::DeviceBatchedContractionMultipleD; + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int num_kernel = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + num_kernel++; + + if((instance_index != -1) && (instance_index + 1 != num_kernel)) + { + // skip test if instance_index is specified + continue; + } + + // re-init E to zero before profiling next kernel + e_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + ck::index_t G = ck::accumulate_n( + e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); + + ck::index_t M = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = 
sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + pass = + pass & ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result); + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + if(instance_index != -1) + { + std::cout << "batched_contraction_instance (" << instance_index << "/" << num_kernel + << "): Passed" << std::endl; + } + + if(fail_if_no_supported_instances && num_kernel == 0) + { + return false; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f9ad14d654..9fee3b5697 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -271,6 +271,7 @@ add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) add_subdirectory(gemm_universal_reduce) +add_subdirectory(batched_contraction) add_subdirectory(batched_gemm) add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_gemm) diff --git a/test/batched_contraction/CMakeLists.txt b/test/batched_contraction/CMakeLists.txt new file mode 100644 index 0000000000..b0a1b823d6 --- /dev/null +++ b/test/batched_contraction/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its 
affiliates. +# SPDX-License-Identifier: MIT + +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_batched_contraction test_batched_contraction.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_contraction PRIVATE utility device_batched_gemm_bias_permute_instance) + endif() +endif() \ No newline at end of file diff --git a/test/batched_contraction/test_batched_contraction.cpp b/test/batched_contraction/test_batched_contraction.cpp new file mode 100644 index 0000000000..eb6134e673 --- /dev/null +++ b/test/batched_contraction/test_batched_contraction.cpp @@ -0,0 +1,164 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_batched_contraction_multiple_d_impl.hpp" + +static ck::index_t param_mask = 0xffff; +static ck::index_t instance_index = -1; + +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +template +class TestBatchedContraction : public ::testing::Test +{ + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using DsDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using AElementOp = std::tuple_element_t<4, Tuple>; + using BElementOp = std::tuple_element_t<5, Tuple>; + using CDEElementOp = std::tuple_element_t<6, Tuple>; + + static constexpr ck::index_t NumDimG = 1; + static constexpr ck::index_t NumDimM = 2; + static constexpr ck::index_t NumDimN = 3; + static constexpr ck::index_t NumDimK = 1; + + protected: + struct GemmParams + { + std::array Gs; + std::array Ms; + std::array Ns; + std::array Ks; + }; + + bool bench_ = true; + bool verify_ = true; + bool do_log_ = true; + int init_method_ = 1; + + 
std::vector params; + + void Run() + { + bool pass = true; + for(size_t i = 0; i < params.size(); i++) + { + if((param_mask & (1 << i)) == 0) + { + continue; + } + auto& param = params[i]; + + pass = pass && ck::profiler::profile_batched_contraction_multiple_d_impl( + verify_, + init_method_, + do_log_, + bench_, + param.Gs, + param.Ms, + param.Ns, + param.Ks, + instance_index, + true); + } + EXPECT_TRUE(pass); + } +}; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, F16, PassThrough, PassThrough, Add> +>; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedContraction, KernelTypes); + +TYPED_TEST(TestBatchedContraction, BaseCase) +{ + this->params = std::vector{ + // Gs, Ms, Ns, Ks + {{1}, {4, 128}, {4, 16, 32}, {256}}, + {{4}, {4, 128}, {4, 16, 32}, {256}}, + }; + this->Run(); +} +TYPED_TEST(TestBatchedContraction, TinyCases) +{ + this->params = std::vector{ + // Gs, Ms, Ns, Ks + {{1}, {1, 16}, {1, 1, 16}, {16}}, + {{2}, {4, 8}, {2, 2, 8}, {32}}, + }; + this->Run(); +} +TYPED_TEST(TestBatchedContraction, PadM) +{ + this->params = std::vector{ + // Gs, Ms, Ns, Ks + {{1}, {1, 130}, {2, 4, 32}, {256}}, + }; + this->Run(); +} + +// Disabled: Currently fails on the XDL instances +TYPED_TEST(TestBatchedContraction, DISABLED_PadN) +{ + this->params = std::vector{ + // Gs, Ms, Ns, Ks + {{1}, {1, 128}, {1, 1, 66}, {256}}, + }; + this->Run(); +} + +// Disabled: Currently fails on the WMMA and XDL instances +TYPED_TEST(TestBatchedContraction, DISABLED_PadK) +{ + this->params = std::vector{ + // Gs, Ms, Ns, Ks + {{1}, {1, 128}, {1, 1, 64}, {258}}, + }; + this->Run(); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + param_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} From 
1a6d1b59ef7358e4f07afcc0a163af7aa4b985a9 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Mon, 19 Jan 2026 10:54:10 +0100 Subject: [PATCH 22/99] [CK_BUILDER] Convolution forward transfer concepts. (#3535) * Rename member variable to better reflect its actual meaning. * Add transfer checks for conv fwd xdl. * Validate tensor layouts & vector size conv fwd v3. * Add combined transfer concepts. * Add transfer concepts for conv fwd factories. * Fix clang format * Add helper instruction to get max mem vector instruction width. * Apply review comments. * Rename thread cluster access(->arrange) order concept * Fix merge artifacts. * Add generic access order limits into block transfer concept. --- .../builder/conv_algorithm_concepts.hpp | 12 +- .../ck_tile/builder/conv_algorithm_limits.hpp | 223 ++++++++++++++++++ .../builder/factory/conv_algorithms.hpp | 2 +- .../factory/conv_fwd_large_tensor_factory.hpp | 57 ++++- .../builder/factory/conv_fwd_v3_factory.hpp | 66 +++++- .../builder/factory/conv_fwd_wmma_factory.hpp | 58 ++++- .../builder/factory/conv_fwd_xdl_factory.hpp | 66 +++++- .../helpers/ck/conv_block_transfer.hpp | 4 +- .../test/impl/conv_algorithm_types.hpp | 6 +- .../builder/test/test_conv_description.cpp | 32 +-- .../test/utils/ckb_conv_test_configs.hpp | 192 +++++++-------- .../test/utils/conv_algorithm_type_utils.hpp | 2 +- include/ck_tile/core/arch/arch.hpp | 7 + 13 files changed, 570 insertions(+), 157 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 791924ccd4..29a04d9b6c 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -104,7 +104,7 @@ concept EpilogueDescriptor = requires(T t) { // Concept for the thread cluster access order template -concept AccessOrderDescriptor 
= requires(T t) { +concept ThreadClusterOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; } || requires(T t) { { t.order } -> std::convertible_to>; @@ -231,16 +231,16 @@ concept SpecifiesLdsTransfer = requires(T t) { // Concept to check if a struct specifies thread cluster access order info. template -concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor; - { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor; +concept SpecifiesThreadClusterArrangeOrder = requires(T t) { + { T::transfer.a.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor; + { T::transfer.b.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor; }; // Concept to check if a struct specifies source access order info. template concept SpecifiesSourceAccessOrder = requires(T t) { - { T::transfer.a.src_access_order } -> AccessOrderDescriptor; - { T::transfer.b.src_access_order } -> AccessOrderDescriptor; + { T::transfer.a.src_access_order } -> ThreadClusterOrderDescriptor; + { T::transfer.b.src_access_order } -> ThreadClusterOrderDescriptor; }; // Concept to check if struct specifies block GEMM. 
diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index d35897fc78..5196eae6c7 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -5,6 +5,9 @@ #include #include +#include +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/arch/arch.hpp" namespace ck_tile::builder { @@ -45,4 +48,224 @@ concept AccessOrderLimits4D = requires { (Value.Size() == 4)); }; +namespace detail { + +// Helper to check if access order is a valid permutation +template +constexpr bool is_valid_permutation() +{ + constexpr auto size = Value.Size(); + + // Check all values are in range [0, size) + for(size_t i = 0; i < size; ++i) + { + if(Value[i] < 0 || Value[i] >= static_cast(size)) + return false; + } + + // Check all values are unique (valid permutation) + for(size_t i = 0; i < size; ++i) + { + for(size_t j = i + 1; j < size; ++j) + { + if(Value[i] == Value[j]) + return false; + } + } + + return true; +} + +} // namespace detail + +// Generic access order limits. Must be a valid permutation of {0, 1, ..., Dims-1}. +// Works with both 3D and 4D (or any dimensionality) access orders. 
+template +concept AccessOrderLimits = requires { + requires Value.Size() == Dims; + requires detail::is_valid_permutation(); +}; + +namespace detail { + +// Helper trait to get compile-time size from ck::Array +template +concept HasStaticSize = requires { + { T::Size() } -> std::convertible_to; +}; + +// Helper trait to get compile-time size from std::array and similar +template +concept HasTupleSize = requires { + { std::tuple_size::value } -> std::convertible_to; +}; + +// Helper for dependent static_assert +template +constexpr bool always_false = false; + +// Get compile-time size of a range +template +constexpr size_t get_range_size() +{ + if constexpr(HasStaticSize) + { + return Range::Size(); + } + else if constexpr(HasTupleSize) + { + return std::tuple_size_v; + } + else + { + static_assert(always_false, "Unsupported type of range object."); + } +} + +// Fold expression implementation for product calculation +template +constexpr auto get_cluster_size_impl(const Range& range, std::index_sequence) +{ + using value_type = std::remove_cvref_t; + return ((range[Is]) * ... * value_type{1}); +} + +// Generic function that calculates the product of all elements in a range +// Works with any indexable range with compile-time size (ck::Array, std::array, etc.) 
+template + requires requires(Range r) { + r[0]; // Must be indexable + get_range_size(); // Must have compile-time size + } +constexpr auto get_cluster_size(const Range& range) +{ + return get_cluster_size_impl(range, std::make_index_sequence()>{}); +} + +// Calculate K dimension coverage (k0 * k1, with vectorization if applicable) +template +constexpr auto get_k_coverage() +{ + auto k0 = BlockTransfer.thread_cluster_dims[0]; + auto k1 = BlockTransfer.thread_cluster_dims[2]; + auto k_total = k0 * k1; + + // If vectorization is on k0 (dim 0) or k1 (dim 2), multiply by vector size + if constexpr(BlockTransfer.src_vector_dim == 0 || BlockTransfer.src_vector_dim == 2) + { + k_total *= BlockTransfer.src_scalar_per_vector; + } + + return k_total; +} + +// Calculate M/N dimension coverage (m_n, with vectorization if applicable) +template +constexpr auto get_mn_coverage() +{ + auto mn = BlockTransfer.thread_cluster_dims[1]; + + // If vectorization is on m_n (dim 1), multiply by vector size + if constexpr(BlockTransfer.src_vector_dim == 1) + { + mn *= BlockTransfer.src_scalar_per_vector; + } + + return mn; +} + +template +constexpr auto get_data_max_vec_size() +{ + constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width(); + static_assert(max_vec_inst_size_bytes % DataTypeSize == 0, + "The max vec instruction size is not a multiple of given data type size."); + return max_vec_inst_size_bytes / DataTypeSize; +} + +} // namespace detail + +// product of thread cluster lengths must be <= workgroup size +template +concept ValidBlockTransferClusterSize = + requires { requires detail::get_cluster_size(BlockTransfer.thread_cluster_dims) <= BlockSize; }; + +// Check that thread cluster covers the K and M dimensions for A transfer +template +concept ThreadsCoverATile = requires { + // K dimension: k0 * k1 * (vectorization) must divide K + requires TileSize.k % detail::get_k_coverage() == 0; + // M dimension: m_n * (vectorization) must divide M + requires TileSize.m % 
detail::get_mn_coverage() == 0; +}; + +// Check that thread cluster covers the K and N dimensions for B transfer +template +concept ThreadsCoverBTile = requires { + // K dimension: k0 * k1 * (vectorization) must divide K + requires TileSize.k % detail::get_k_coverage() == 0; + // N dimension: m_n * (vectorization) must divide N + requires TileSize.n % detail::get_mn_coverage() == 0; +}; + +template +concept ThreadsCoverCTile = requires { + // M dimension: m_wave_per_xdl must divide M + requires TileSize.m % CBlockTransfer.thread_cluster_dims[1] == 0; + // N dimension: n_wave_per_xdl * (vectorization) must divide N + requires TileSize.n % (CBlockTransfer.thread_cluster_dims[3] * + CBlockTransfer.scalar_per_vector) == 0; +}; + +template +concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0); + +template +concept IsVectorSizeValid = + IsPowerOf2 && (ScalarPerVec <= detail::get_data_max_vec_size()); + +// Composite concept for input block transfer validation (A) +// Includes all validations: vector transfer limits, access order, cluster size, +// vector size validity, and tile coverage +template +concept ValidABlockTransfer = + InputVectorTransferLimits && + AccessOrderLimits && + AccessOrderLimits && + ValidBlockTransferClusterSize && + IsVectorSizeValid && + IsVectorSizeValid && + ThreadsCoverATile; + +// Composite concept for input block transfer validation (B) +template +concept ValidBBlockTransfer = + InputVectorTransferLimits && + AccessOrderLimits && + AccessOrderLimits && + ValidBlockTransferClusterSize && + IsVectorSizeValid && + IsVectorSizeValid && + ThreadsCoverBTile; + +// Composite concept for output block transfer validation (C) +template +concept ValidCBlockTransfer = + OutputVectorTransferLimits && + ValidBlockTransferClusterSize && + IsVectorSizeValid && + ThreadsCoverCTile; + +// Usage: IsValidLayout +template +concept IsValidLayout = ck_tile::is_any_value_of(ACTUAL_LAYOUT, VALID_LAYOUTS...); + } // namespace ck_tile::builder diff --git 
a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index fc0ee48ec0..79b818555e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -11,7 +11,7 @@ namespace ck_tile::builder::factory { template concept TileTransferParameters = SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder; + SpecifiesThreadClusterArrangeOrder && SpecifiesSourceAccessOrder; template concept SpecifiesTileTransferParameters3D = TileTransferParameters; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 0ff410d731..b80406c37e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -46,14 +46,55 @@ struct ConvFwdLargeTensorFactory internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); - // Check limits for the algorithm parameters. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); + // Check limits for the data transfer parameters. 
+ static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance with large tensor support. using Instance = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index dd2fa65eae..74554df7e9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -52,14 +52,64 @@ struct ConvFwdXdlV3Factory static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); + static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + + // Layout validations + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 2d6f7c394b..cb36122f7c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -48,14 +48,56 @@ struct ConvFwdWmmaFactory static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); + static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + // TODO: verify Ds transfer as well + + // Layout validations (same as DeviceGroupedConvFwdMultipleD_Wmma_CShuffle) + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index e03e035969..b3be21f1f3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -47,14 +47,64 @@ struct ConvFwdXdlFactory static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); - static_assert(AccessOrderLimits3D); + static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + + // Layout validations (same as DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle) + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index d873a4b903..249fe0ba24 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -27,7 +27,7 @@ template constexpr BlockTransfer<> SetFwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_order = TRANSFER.thread_cluster_arrange_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; @@ -47,7 +47,7 @@ template constexpr auto SetBwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_order = TRANSFER.thread_cluster_arrange_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 617686fda1..b775505a26 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -126,15 +126,15 @@ struct AccessOrder { std::array order; }; -static_assert(AccessOrderDescriptor>); -static_assert(AccessOrderDescriptor>); +static_assert(ThreadClusterOrderDescriptor>); +static_assert(ThreadClusterOrderDescriptor>); template struct InputTransfer { BlockTransfer block_transfer; LdsTransfer lds_transfer; - AccessOrder block_transfer_access_order; + AccessOrder thread_cluster_arrange_order; AccessOrder src_access_order; }; diff --git a/experimental/builder/test/test_conv_description.cpp 
b/experimental/builder/test/test_conv_description.cpp index 9e8008ccf0..bcea406fa7 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -128,26 +128,26 @@ struct DefaultAlgorithm ckb::test::Transfer<> transfer{ .a = { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, - .src_access_order = {.order = {0, 1, 2}}, + .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, }, .b = { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, - .src_access_order = {.order = {0, 1, 2}}, + .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, }, .c = { diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 3b83ead2d0..e48f1dd6ba 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -53,25 +53,25 @@ constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, constexpr Transfer<> Transfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, 
.m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 4, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .c = { @@ -86,25 +86,25 @@ constexpr Transfer<> Transfer_4x64x1{ constexpr Transfer<4> BwdTransfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 4, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {0, 3, 1, 2}, - .src_access_order = {0, 2, 1, 3}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 
2, - .lds_dst_scalar_per_vector = 4, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {0, 3, 1, 2}, - .src_access_order = {0, 2, 1, 3}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .c = { @@ -119,25 +119,25 @@ constexpr Transfer<4> BwdTransfer_4x64x1{ constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ .a = { - .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 1, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {2, 0, 1}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 1, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {2, 0, 1}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, }, .c = { @@ -152,25 +152,25 @@ constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ constexpr Transfer<> Transfer_4x64x1_fp8{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - 
.src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .c = { @@ -185,25 +185,25 @@ constexpr Transfer<> Transfer_4x64x1_fp8{ constexpr Transfer<> Transfer_4x16x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - 
.src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .c = { @@ -219,25 +219,25 @@ constexpr Transfer<> Transfer_4x16x1{ constexpr Transfer<> Transfer_4x32x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .c = { diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 23f4cf3364..178029e338 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -165,7 +165,7 @@ template inline std::string to_string(InputTransfer t) { std::ostringstream 
oss; - oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," + oss << to_string(t.block_transfer) << "," << to_string(t.thread_cluster_arrange_order) << "," << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << "," << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector << "," << (t.lds_transfer.lds_padding ? "true" : "false"); diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 97e962f5a3..ce6a1349e5 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1173,4 +1173,11 @@ enum LLVMSchedGroupMask : int32_t DS_WRITE = 1 << 9, ALL = (DS_WRITE << 1) - 1, }; + +CK_TILE_HOST_DEVICE static constexpr auto get_max_mem_vec_inst_width() +{ + // Currently on all arch max memory vector instruction width is 16 bytes. + return 16; +} + } // namespace ck_tile From 66d6a1cfa6807866487becc87cba95a0965f51f9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 07:41:59 -0800 Subject: [PATCH 23/99] Bump rocm-docs-core[api_reference] from 1.31.2 to 1.31.3 in /docs/sphinx (#3602) Bumps [rocm-docs-core[api_reference]](https://github.com/ROCm/rocm-docs-core) from 1.31.2 to 1.31.3. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.31.2...v1.31.3) --- updated-dependencies: - dependency-name: rocm-docs-core[api_reference] dependency-version: 1.31.3 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index b37c5c5652..f2fb27e2b9 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.31.2 +rocm-docs-core[api_reference]==1.31.3 sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 7f0d71cc4b..23397503df 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.31.2 +rocm-docs-core[api-reference]==1.31.3 # via -r requirements.in rpds-py==0.24.0 # via From 98abfa4ade0f7b5204adf4da00e95be9453dce74 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Mon, 19 Jan 2026 12:23:06 -0800 Subject: [PATCH 24/99] Optimize clang-format check in Jenkins CI (#3597) This change improves the clang-format CI check to be faster and not depend on git being available in the build environment. Changes: - Use `find` instead of `git ls-files` (no git dependency) - Check all C++ files: *.h, *.hpp, *.cpp, *.h.in, *.hpp.in, *.cpp.in, *.cl - Exclude build/ and include/rapidjson directories - Use parallel processing with 8 cores (-P 8) for ~8x speedup - Show only errors with unified diff format (-u) - Clear error messages: "ERROR: needs formatting" - Preserve original logic: run clang-format only when RUN_CPPCHECK=false, or run both clang-format and cppcheck when RUN_CPPCHECK=true Performance: - Sequential processing: ~93 seconds for 5,899 files - Parallel with 8 cores: ~12 seconds for 5,899 files - Per-file processing time: ~15ms This reduces CI time while maintaining code formatting standards. 
--- Jenkinsfile | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index e8ce97780d..58b5194f60 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1318,21 +1318,15 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "(cd .. && git ls-files \'*.h\' \ - \'*.hpp\' \ - \'*.cpp\' \ - \'*.h.in\' \ - \'*.hpp.in\' \ - \'*.cpp.in\' \ - \'*.cl\' \ - | grep -v 'build/' \ - | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\') && \ + execute_cmd = """cd .. && \ + find . -type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \ + -not -path '*/build/*' -not -path '*/include/rapidjson/*' | \ + xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \ - --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log" + --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd) @@ -1348,17 +1342,10 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "(cd .. 
&& git ls-files \ - \'*.h\' \ - \'*.hpp\' \ - \'*.cpp\' \ - \'*.h.in\' \ - \'*.hpp.in\' \ - \'*.cpp.in\' \ - \'*.cl\' \ - | grep -v 'build/' \ - | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\')" + execute_cmd = """cd .. && \ + find . -type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \ + -not -path '*/build/*' -not -path '*/include/rapidjson/*' | \ + xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)'""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd) From f3aafb95552cc2570f952667848310fbe3e982e7 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 20 Jan 2026 07:22:33 +0800 Subject: [PATCH 25/99] [CK_TILE][FMHA] Add new tile size for async (#3586) * add new tile size for async Signed-off-by: Linjun-AMD * Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix lse error Signed-off-by: Linjun-AMD --------- Signed-off-by: Linjun-AMD Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 +++++++- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b..81c7b067d3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -315,7 +315,7 @@ class FmhaFwdApiTrait: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0: + if self.bm0 == max_bm0 or self.bm0 == 64: return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}" @@ -847,6 +847,11 @@ class 
CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) and kernel_ctx.tile.F_bm0 != 128 ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.pipeline.tag != "qr_async" + and kernel_ctx.tile.F_bk0 == 64 + ) ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 @@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7224ed3a70..e30d4215d6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -321,6 +321,8 @@ struct BlockFmhaPipelineQRKSVSAsync { if(num_total_loop <= 0) { + buffer_load_fence(0); // rocm-7.1.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) 
if constexpr(kStoreLSE) { auto lse = From 0517d43d312356c62cc33bea4f0ecc5613e87079 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Mon, 19 Jan 2026 16:31:02 -0700 Subject: [PATCH 26/99] [CK TILE] remove dependency on std chrono (#3599) * [CK TILE] remove dependency on std chrono * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/ck_tile/host.hpp | 1 + include/ck_tile/host/high_res_cpu_clock.hpp | 95 +++++++++++++++++++++ include/ck_tile/host/timer.hpp | 16 ++-- 3 files changed, 103 insertions(+), 9 deletions(-) create mode 100644 include/ck_tile/host/high_res_cpu_clock.hpp diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index b543fd84e9..014fcfdd65 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -11,6 +11,7 @@ #include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/fill.hpp" #include "ck_tile/host/flush_icache.hpp" +#include "ck_tile/host/high_res_cpu_clock.hpp" #include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/joinable_thread.hpp" diff --git a/include/ck_tile/host/high_res_cpu_clock.hpp b/include/ck_tile/host/high_res_cpu_clock.hpp new file mode 100644 index 0000000000..c86f7368d4 --- /dev/null +++ b/include/ck_tile/host/high_res_cpu_clock.hpp @@ -0,0 +1,95 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +namespace ck_tile { + +// Time structure to hold nanoseconds since epoch or arbitrary start point +struct timepoint_t +{ + int64_t nanoseconds; +}; + +// Platform-specific includes and implementation +#if defined(_WIN32) || defined(_WIN64) +// Windows +#include + +static inline timepoint_t high_res_now() +{ + // Cache the performance counter frequency; it is constant for the system lifetime. 
+ static LARGE_INTEGER frequency = []() { + LARGE_INTEGER f; + QueryPerformanceFrequency(&f); + return f; + }(); + + LARGE_INTEGER counter; + timepoint_t tp; + QueryPerformanceCounter(&counter); + + // Convert to nanoseconds using floating-point to avoid 64-bit integer overflow + tp.nanoseconds = + static_cast((static_cast(counter.QuadPart) * 1000000000.0L) / + static_cast(frequency.QuadPart)); + + return tp; +} + +#elif defined(__linux__) || defined(__unix__) || defined(_POSIX_VERSION) +// Linux/Unix/POSIX +#include + +static inline timepoint_t high_res_now() +{ + struct timespec ts; + timepoint_t tp; + + // Use CLOCK_MONOTONIC for consistent timing unaffected by system time changes + // Use CLOCK_REALTIME if you need wall-clock time + clock_gettime(CLOCK_MONOTONIC, &ts); + + tp.nanoseconds = static_cast(ts.tv_sec * 1000000000LL + ts.tv_nsec); + + return tp; +} + +#else +// Fallback for other platforms +#include + +static inline timepoint_t high_res_now() +{ + timepoint_t tp; + time_t t = time(NULL); + tp.nanoseconds = static_cast(t * 1000000000LL); + return tp; +} + +#endif + +// Duration calculation functions +static inline int64_t duration_ns(timepoint_t start, timepoint_t end) +{ + return end.nanoseconds - start.nanoseconds; +} + +static inline int64_t duration_us(timepoint_t start, timepoint_t end) +{ + return (end.nanoseconds - start.nanoseconds) / 1000LL; +} + +static inline int64_t duration_ms(timepoint_t start, timepoint_t end) +{ + return (end.nanoseconds - start.nanoseconds) / 1000000LL; +} + +static inline double duration_sec(timepoint_t start, timepoint_t end) +{ + return static_cast(end.nanoseconds - start.nanoseconds) / 1000000000.0; +} + +} // namespace ck_tile diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp index 1d641d1812..a300c877e8 100644 --- a/include/ck_tile/host/timer.hpp +++ b/include/ck_tile/host/timer.hpp @@ -5,9 +5,9 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include 
"ck_tile/host/high_res_cpu_clock.hpp" #include #include -#include namespace ck_tile { @@ -54,26 +54,24 @@ struct cpu_timer CK_TILE_HOST void start(const hipStream_t& s) { HIP_CHECK_ERROR(hipStreamSynchronize(s)); - start_tick = std::chrono::high_resolution_clock::now(); + start_tick = high_res_now(); } // torch.utils.benchmark.Timer(), there is a sync inside each timer callback CK_TILE_HOST void stop(const hipStream_t& s) { HIP_CHECK_ERROR(hipStreamSynchronize(s)); - stop_tick = std::chrono::high_resolution_clock::now(); + stop_tick = high_res_now(); } // return in ms CK_TILE_HOST float duration() const { - double sec = - std::chrono::duration_cast>(stop_tick - start_tick) - .count(); - return static_cast(sec * 1e3); + auto us = duration_us(start_tick, stop_tick); + return static_cast(us) / 1e3; } private: - std::chrono::time_point start_tick; - std::chrono::time_point stop_tick; + timepoint_t start_tick; + timepoint_t stop_tick; }; } // namespace ck_tile From 0727e85e523aac7a1e82af00f44081cc67f5cde0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 20 Jan 2026 06:29:01 +0100 Subject: [PATCH 27/99] [CK_BUILDER] Add grouped conv fwd ck tile profiler (#3518) * [BULDER] Add grouped conv fwd ck tile profiler * [CK TILE] Fix grouped conv kernels splitk and double lds * Updates * Fixes * Move to ckProfiler * Fixes * fix * fix * Change instances to empty list by default * fix * fix * Update grouped_convolution_signatures.hpp * Update grouped_convolution_forward_tile_algs.hpp * [CK TILE] Add grouped convolution forward tests (#3556) * [CK TILE] Add grouped convolution forward tests * fix jenkins * fixes * comments fixes * unit test * unit test fix * Move instances outside builder * fix includes * clang format fix * readme fix * fix includes * fixes --- .gitignore | 4 + CMakeLists.txt | 9 +- Jenkinsfile | 38 ++- .../builder/factory/conv_tile_factory.hpp | 41 ++- .../ck_tile/builder/testing/conv_fwd.hpp | 33 +++ 
.../ck_tile/builder/testing/conv_fwd_ck.hpp | 18 +- .../builder/testing/conv_fwd_ck_tile.hpp | 91 ++++++ .../builder/testing/conv_fwd_reference.hpp | 31 +- .../ck_tile/builder/testing/filter_extent.hpp | 21 ++ .../builder/testing/tensor_descriptor.hpp | 13 + .../ck_tile/builder/testing/testing.hpp | 14 +- .../conv/ck_tile/test_ckb_conv_fwd_e2e.cpp | 84 ++++++ .../test/impl/conv_signature_types.hpp | 8 + .../builder/test/unit_tensor_descriptor.cpp | 2 + .../test/utils/ckb_conv_tile_test_configs.hpp | 4 +- .../CMakeLists.txt | 19 ++ .../README.md | 5 + .../configs/profiler/ndhwgc_bf16.conf | 237 +++++++++++++++ .../configs/profiler/ndhwgc_fp16.conf | 228 +++++++++++++++ .../configs/profiler/ndhwgc_fp32.conf | 176 +++++++++++ .../configs/profiler/nhwgc_bf16.conf | 237 +++++++++++++++ .../configs/profiler/nhwgc_fp16.conf | 228 +++++++++++++++ .../configs/profiler/nhwgc_fp32.conf | 176 +++++++++++ .../configs/tests/ndhwgc_bf16.conf | 41 +++ .../configs/tests/ndhwgc_fp16.conf | 41 +++ .../configs/tests/ndhwgc_fp32.conf | 42 +++ .../configs/tests/nhwgc_bf16.conf | 41 +++ .../configs/tests/nhwgc_fp16.conf | 41 +++ .../configs/tests/nhwgc_fp32.conf | 42 +++ .../generate_instances.py | 275 ++++++++++++++++++ .../grouped_convolution_forward_tile.cpp.in | 19 ++ .../instances/instance_includes.inc | 64 ++++ .../instances/instance_run.inc | 9 + include/ck/library/utility/host_tensor.hpp | 20 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 2 + include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 12 + .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 + .../grouped_convolution_forward_kernel.hpp | 46 ++- .../grouped_convolution_forward_tile_algs.hpp | 169 +++++++++++ .../grouped_convolution_signatures.hpp | 70 +++++ profiler/src/CMakeLists.txt | 9 + .../src/profile_grouped_conv_fwd_tile.cpp | 201 +++++++++++++ test/grouped_convnd_fwd/CMakeLists.txt | 12 + .../test_grouped_convnd_fwd_tile.cpp | 273 +++++++++++++++++ 44 files changed, 3083 insertions(+), 65 deletions(-) create mode 100644 
experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp create mode 100644 experimental/grouped_convolution_tile_instances/CMakeLists.txt create mode 100644 experimental/grouped_convolution_tile_instances/README.md create mode 100644 experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_bf16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp32.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_bf16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp32.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf create mode 100644 experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf create mode 100644 experimental/grouped_convolution_tile_instances/generate_instances.py create mode 100644 experimental/grouped_convolution_tile_instances/instances/grouped_convolution_forward_tile.cpp.in create mode 100644 experimental/grouped_convolution_tile_instances/instances/instance_includes.inc create mode 100644 experimental/grouped_convolution_tile_instances/instances/instance_run.inc create mode 100644 profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp create 
mode 100644 profiler/include/profiler/grouped_convolution_signatures.hpp create mode 100644 profiler/src/profile_grouped_conv_fwd_tile.cpp create mode 100644 test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp diff --git a/.gitignore b/.gitignore index 98234268c1..740d5464fb 100644 --- a/.gitignore +++ b/.gitignore @@ -92,3 +92,7 @@ test_data/* # The experimental/builder directory should be tracked despite matching build* !experimental/builder !experimental/builder/** +experimental/grouped_convolution_tile_instances/instances/* +!experimental/grouped_convolution_tile_instances/instances/*.in +!experimental/grouped_convolution_tile_instances/instances/*.inc +experimental/grouped_convolution_tile_instances/*.inc diff --git a/CMakeLists.txt b/CMakeLists.txt index 121c663f64..cd7121b39d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -704,6 +704,11 @@ option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) add_subdirectory(library) +if (CK_EXPERIMENTAL_BUILDER) + add_subdirectory(experimental/builder) + add_subdirectory(experimental/grouped_convolution_tile_instances) +endif() + if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel @@ -735,10 +740,6 @@ if (NOT MIOPEN_REQ_LIBS_ONLY) add_subdirectory(profiler) endif() -if (CK_EXPERIMENTAL_BUILDER) - add_subdirectory(experimental/builder) -endif() - if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) add_subdirectory(codegen) endif() diff --git a/Jenkinsfile b/Jenkinsfile index 58b5194f60..2f2229c7a5 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -580,7 +580,7 @@ def cmake_build(Map conf=[:]){ if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" } - if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { + if ((params.RUN_BUILDER_TESTS || params.RUN_FULL_CONV_TILE_TESTS) && 
!setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { setup_args = " -D CK_EXPERIMENTAL_BUILDER=ON " + setup_args } setup_cmd = conf.get( @@ -1091,7 +1091,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_ 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true - 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true + 0 13 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true 0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" pipeline { @@ -1255,6 +1255,10 @@ pipeline { name: "RUN_AITER_TESTS", defaultValue: false, description: "Run AITER tests with latest CK develop branch (default: OFF)") + booleanParam( + name: "RUN_FULL_CONV_TILE_TESTS", + defaultValue: false, + description: "Run CK Tile grouped convolution tests with latest CK develop branch (default: OFF)") string( name: 'aiter_branch', defaultValue: 'main', @@ -1410,6 +1414,36 @@ pipeline { } } } + stage("Run Full Grouped Conv Tile Tests") + { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } + parallel + { + stage("Run Full Grouped Conv Tile Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_FULL_CONV_TILE_TESTS.toBoolean() } + } + agent{ label 
rocmnode("gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ python3 ../experimental/builder/src/generate_instances.py --mode=profiler && \ + ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 test_grouped_convnd_fwd_tile && \ + ./bin/test_grouped_convnd_fwd_tile""" + } + steps{ + // TODO: Reenable after the instance fixes + // buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Run Grouped Conv Large Case Tests") { when { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp index 6ce508b47d..35c87b61ce 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -98,27 +98,26 @@ struct ConvTileFactory using GemmPipeline = typename internal::TilePipelineType< BLOCK_GEMM.pipeline_version>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Instance = typename internal::GroupedConvolutionTileKernel Ops::WeiElementwiseOp b_elementwise_op; Ops::OutElementwiseOp cde_elementwise_op; + int k_batch = 1; + /// This function returns the `TensorDescriptor` corresponding to /// the input-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. @@ -169,6 +172,36 @@ struct Args to_vector(this->input_left_pad), to_vector(this->input_right_pad)); } + + /// Convert the Args structure into a CK Tile conv_param structure. + /// This function is mainly used to be able to use the existing + /// CK Tile functionality to obtain tensor descriptors. 
+ ck_tile::conv::ConvParam to_ck_tile_conv_param() const + { + const auto to_vector = [](const auto& extent) { + if constexpr(SPATIAL_DIM == 1) + return std::vector{ck::index_t(extent.width)}; + else if constexpr(SPATIAL_DIM == 2) + return std::vector{ck::index_t(extent.height), + ck::index_t(extent.width)}; + else + return std::vector{ck::index_t(extent.depth), + ck::index_t(extent.height), + ck::index_t(extent.width)}; + }; + + return ck_tile::conv::ConvParam(SPATIAL_DIM, + this->lengths.groups, + this->lengths.batch_size, + this->lengths.output_channels, + this->lengths.input_channels, + to_vector(this->lengths.filter), + to_vector(this->lengths.image), + to_vector(this->filter_strides), + to_vector(this->filter_dilation), + to_vector(this->input_left_pad), + to_vector(this->input_right_pad)); + } }; /// @brief `Inputs` specialization for forward convolution. diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp index a90f53ba7d..f911dca21c 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include #include @@ -87,16 +88,19 @@ concept CkConvInstance = detail::CkConvInstance; /// @brief `run()` specialization for forward convolution and old CK. /// /// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments werent actually valid for the +/// @throws std::runtime_error if the arguments weren't actually valid for the /// operation. This should be caught and reported by the testing framework. +/// @return std::tuple - whether the problem is supported and +/// kernel execution time (0.0f if s_conf time_kernel is false). 
/// /// @see run() template requires ValidConvSignature && ConvDirectionIsForward -void run(CkConvInstance auto& conv, - const Args& args, - const Inputs& inputs, - const Outputs& outputs) +std::tuple run(CkConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs, + const StreamConfig s_conf = {}) { constexpr auto spatial_dim = SIGNATURE.spatial_dim; @@ -144,10 +148,10 @@ void run(CkConvInstance auto& conv, if(!conv.IsSupportedArgument(ck_args)) { - throw std::runtime_error("invalid argument"); + std::cout << "invalid argument" << std::endl; } - conv.MakeInvoker().Run(ck_args, {}); + return std::make_tuple(true, conv.MakeInvoker().Run(ck_args, s_conf)); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp new file mode 100644 index 0000000000..a8f6825524 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp @@ -0,0 +1,91 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// grouped convolution operations in CK Tile. The main item is the +/// `run()` function, which is the main implementation used to invoke +/// CK Tile grouped forward convolution kernels. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Concept for checking whether this is the CK Tile convolution +/// implementation. +/// +/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except +/// with some utility aliases. 
For that reason, its moved to this detail +/// namespace. +template +concept CkTileConvInstance = requires(Conv&) { + { Conv::BlockSize() }; +}; + +} // namespace detail + +/// @brief Concept for checking whether a convolution is invoked like CK Tile. +/// +/// This concept is used to tell whether a convolution implementation is +/// likely to be an "CK Tile" implementation - that is, whether we should +/// invoke it as an CK Tile kernel. This is mainly used with `run()` to +/// differentiate which implementation that should be invoked. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept CkTileConvInstance = detail::CkTileConvInstance; + +/// @brief `run()` specialization for forward convolution and CK Tile. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @throws std::runtime_error if the arguments weren't actually valid for the +/// operation. This should be caught and reported by the testing framework. +/// @return std::tuple - whether the problem is supported and +/// kernel execution time (0.0f if s_conf time_kernel is false). +/// +/// @see run() +template + requires ValidConvSignature && ConvDirectionIsForward +std::tuple run(CkTileConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs, + const ck_tile::stream_config s_conf = {}) +{ + using Conv = std::remove_reference_t; + const auto param = args.to_ck_tile_conv_param(); + + ck_tile::GroupedConvFwdHostArgs<> host_args( + param, inputs.input, inputs.weight, {}, outputs.output, args.k_batch); + + auto kargs = Conv::MakeKernelArgs(host_args); + + const dim3 grids = Conv::GridSize(kargs); + const dim3 blocks = Conv::BlockSize(); + + if(!Conv::IsSupportedArgument(kargs)) + { + std::cout << "Not supported!"; + return std::make_tuple(false, 0.f); + } + + constexpr index_t minimum_occupancy = + Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 
1 : 2; + + return std::make_tuple( + true, + ck_tile::launch_kernel( + s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp index 85493e32eb..6401c6a5d5 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp @@ -62,6 +62,8 @@ concept RefConvInstance = requires(Conv& conv, /// @throws std::runtime_error if the arguments weren't actually valid for the /// operation. This should be caught and reported by the testing framework. /// +/// @return std::tuple - whether the problem is supported and +/// kernel execution time (0.0f for reference). /// @see run() template requires ValidConvSignature && @@ -69,10 +71,10 @@ template // for now, just concern outselves with reference and see when the // rest of the bwd/weight plumbing is there. 
ConvDirectionIsForward -void run(RefConvInstance auto& conv, - const Args& args, - const Inputs& inputs, - const Outputs& outputs) +std::tuple run(RefConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) { // We don't want to compute the output dims manually, just get // them via the existing infrastructure @@ -86,15 +88,27 @@ void run(RefConvInstance auto& conv, for(auto right_pad : param.input_right_pads_) { if(right_pad != 0) - throw std::runtime_error("TODO: Support right pad in reference conv"); + { + std::cout << "TODO: Support right pad in reference conv" << std::endl; + return std::make_tuple(false, 0.0f); + } } if(!args.make_input_descriptor().is_packed()) - throw std::runtime_error("TODO: Support non-packed input tensor in reference conv"); + { + std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl; + return std::make_tuple(false, 0.0f); + } if(!args.make_weight_descriptor().is_packed()) - throw std::runtime_error("TODO: Support non-packed weight tensor in reference conv"); + { + std::cout << "TODO: Support non-packed weight tensor in reference conv" << std::endl; + return std::make_tuple(false, 0.0f); + } if(!args.make_output_descriptor().is_packed()) - throw std::runtime_error("TODO: Support non-packed output tensor in reference conv"); + { + std::cout << "TODO: Support non-packed output tensor in reference conv" << std::endl; + return std::make_tuple(false, 0.0f); + } conv.Run(inputs.input, inputs.weight, @@ -109,6 +123,7 @@ void run(RefConvInstance auto& conv, param.conv_filter_strides_, param.conv_filter_dilations_, param.input_left_pads_); + return std::make_tuple(true, 0.0f); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp index 3587ac406f..2fc1f39012 100644 --- 
a/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp @@ -34,4 +34,25 @@ struct FilterExtent<3> size_t depth = 1; }; +template +inline FilterExtent filter_extent_from_vector(const std::vector& vec); + +template <> +inline FilterExtent<1> filter_extent_from_vector<1>(const std::vector& vec) +{ + return FilterExtent<1>{.width = vec[0]}; +} + +template <> +inline FilterExtent<2> filter_extent_from_vector<2>(const std::vector& vec) +{ + return FilterExtent<2>{.width = vec[1], .height = vec[0]}; +} + +template <> +inline FilterExtent<3> filter_extent_from_vector<3>(const std::vector& vec) +{ + return FilterExtent<3>{.width = vec[2], .height = vec[1], .depth = vec[0]}; +} + } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp index 4c99f05c46..6a150a0233 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp @@ -418,6 +418,10 @@ struct TensorDescriptor size_t x = 1; for(size_t i = 0; i < RANK; ++i) { + if(lengths[indices[i]] == 1) + { + continue; + } if(strides[indices[i]] != x) return false; @@ -443,6 +447,15 @@ struct TensorDescriptor return TensorDescriptor(lengths, strides); } + /// @brief Print tensor descriptor details. + /// + /// Print tensor descriptor details - lengths and strides. 
+ friend std::ostream& operator<<(std::ostream& os, const TensorDescriptor& tensor_desc) + { + os << tensor_desc.inner_descriptor_; + return os; + } + private: ck_tile::HostTensorDescriptor inner_descriptor_; }; diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index eb16402bc2..e61d7c4da5 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -317,13 +317,17 @@ ValidationReport validate(const Args& args, /// @param inputs The input tensor data. Will not be modified by this function. /// @param outputs The output tensor data. The contents will be overwritten by /// this function. +/// @param s_conf Stream config used to launch kernel. +/// @return std::tuple - whether the problem is supported and +/// kernel execution time (0.0f if s_conf time_kernel is false). /// /// @note This function is explicitly deleted to generate compile errors /// for missing implementations. -template -void run(Operation& operation, - const Args& args, - const Inputs& inputs, - const Outputs& outputs) = delete; +template +std::tuple run(Operation& operation, + const Args& args, + const Inputs& inputs, + const Outputs& outputs, + const StreamConf s_conf = {}) = delete; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp new file mode 100644 index 0000000000..128744dcc6 --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NHWGK}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(ckb::TileConvSpecialization::DEFAULT) + .with_tile_thread_block(cku::FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(cku::FwdTileTransfer_4x4x4) + .with_tile_optimizations(ckt::TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; +using Reference = ckb::ConvBuilder::Instance; + +TEST(Fwd2DFp16_CShufV3_NHWGC, EndToEnd) +{ + if(!ck_tile::get_device_name().starts_with("gfx9")) + { + GTEST_SKIP() << "unsupported architecture"; + } + + ckt::Args args = { + .lengths = + { + .batch_size = 16, + .groups = 1, + .input_channels = 32, + .output_channels = 48, + .image = + { + .width = 56, + .height = 64, + }, + .filter = + { + .width = 3, + .height = 5, + }, + }, + .filter_strides = {.width = 1, .height = 1}, + .filter_dilation = {.width = 1, .height = 1}, + .input_left_pad = {.width = 0, .height = 0}, + .input_right_pad = {.width 
= 0, .height = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = alloc_inputs(args); + auto outputs = alloc_outputs(args); + auto reference = alloc_outputs(args); + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + ckt::run(conv, args, inputs.get(), outputs.get()); + + auto ref_conv = Reference{}; + ckt::run(ref_conv, args, inputs.get(), reference.get()); + + EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index f046289057..e90e10141d 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -16,6 +16,8 @@ struct TensorConfig // Optional data types, override the type defined in the signature if provided. DataType data_type{DataType::UNDEFINED_DATA_TYPE}; DataType compute_type{DataType::UNDEFINED_DATA_TYPE}; + + constexpr bool operator==(const TensorConfig& other) const = default; }; template @@ -31,6 +33,8 @@ struct TensorOperation return TensorOperation{ .elementwise_operation = this->elementwise_operation}; } + + constexpr bool operator==(const TensorOperation& other) const = default; }; template > @@ -38,6 +42,8 @@ struct ConvolutionTensor { TensorConfig config; Op operation{}; + + constexpr bool operator==(const ConvolutionTensor& other) const = default; }; template , @@ -52,6 +58,8 @@ struct ConvSignature InputTensor input; WeightTensor weight; OutputTensor output; + + constexpr bool operator==(const ConvSignature& other) const = default; }; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/unit_tensor_descriptor.cpp b/experimental/builder/test/unit_tensor_descriptor.cpp index ce6209795a..8e6e269610 100644 --- a/experimental/builder/test/unit_tensor_descriptor.cpp +++ b/experimental/builder/test/unit_tensor_descriptor.cpp 
@@ -190,6 +190,8 @@ TEST(TensorDescriptor, IsPacked) ckt::make_descriptor
(ckt::Extent{10, 11, 12}, ckt::Extent{1, 100, 1100}).is_packed()); EXPECT_FALSE( ckt::make_descriptor
(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed()); + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{30, 20, 1}, ckt::Extent{1, 30, 30}).is_packed()); } TEST(TensorDescriptor, PrintExtent) diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp index 41a1250854..ec59fcca48 100644 --- a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -3,8 +3,8 @@ #pragma once -#include "impl/conv_algorithm_types.hpp" -#include "impl/conv_signature_types.hpp" +#include "../impl/conv_algorithm_types.hpp" +#include "../impl/conv_signature_types.hpp" #include "ck_tile/builder/conv_builder.hpp" namespace ck_tile::builder::test_utils { diff --git a/experimental/grouped_convolution_tile_instances/CMakeLists.txt b/experimental/grouped_convolution_tile_instances/CMakeLists.txt new file mode 100644 index 0000000000..1264a68906 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx9") + # Generate instances using python script (empty to just generate empty instance list) + if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/grouped_convolution_forward_tile_ndhwgc_fp32.inc) + find_package(Python3 COMPONENTS Interpreter Development) + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/generate_instances.py --mode=tests + RESULT_VARIABLE ret + ) + endif() + + # Find cpp files and create lib for instances + file(GLOB_RECURSE GROUPED_CONV_FWD_TILE "instances/*.cpp") + add_instance_library(device_grouped_conv_fwd_tile_instances ${GROUPED_CONV_FWD_TILE}) + target_include_directories(device_grouped_conv_fwd_tile_instances PRIVATE + "${PROJECT_SOURCE_DIR}/experimental/builder/test/utils") +endif() diff --git a/experimental/grouped_convolution_tile_instances/README.md b/experimental/grouped_convolution_tile_instances/README.md new file mode 100644 index 0000000000..1ba5189695 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/README.md @@ -0,0 +1,5 @@ +# Grouped Convolution Tile Instances Generator +CK Tile Convolution instances implemented via builder and generated via python script. +It is integrated with tests and ckProfiler +This functionality will be refactored and moved under the Tile Engine. +At now to speed up development and provide tests for CK Tile Convolution it has been implemented under experimental directory. 
diff --git a/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_bf16.conf new file mode 100644 index 0000000000..ee62db40ba --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_bf16.conf @@ -0,0 +1,237 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 
1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 
32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Default, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Filter1x1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 
224, 256, 64, Filter1x1Stride1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Default, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Stride1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 
8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: 
Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 64, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 64, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 64, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, 
BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, 
Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: 
Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, 
BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: 
Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, 
Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, 
Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp16.conf new file mode 100644 index 
0000000000..466b246787 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp16.conf @@ -0,0 +1,228 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, 
Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 
64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +# 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Default, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Filter1x1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Filter1x1Stride1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Default, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Stride1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 
8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 
256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: 
Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, 
BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: 
Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 
32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 
BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, 
BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 
1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp32.conf new file mode 100644 index 0000000000..7dc982b6f7 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/profiler/ndhwgc_fp32.conf @@ -0,0 +1,176 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Default, 32, 32, 2, 1, 4, 4, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Default, 32, 32, 2, 2, 1, 1, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default, 32, 32, 2, 4, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 16, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 16, Default, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 16, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 16, Default, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Default, 32, 32, 2, 
1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 16, Default, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 192, 16, Default, 32, 32, 2, 3, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 16, Filter1x1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Filter1x1Pad0, 32, 32, 2, 4, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 16, Filter1x1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 16, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 16, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 16, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 
1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 192, 16, Filter1x1Pad0, 32, 32, 2, 3, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 16, Filter1x1Stride1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Filter1x1Stride1Pad0, 32, 32, 2, 4, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 16, Filter1x1Stride1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 16, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 16, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 16, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 192, 16, Filter1x1Stride1Pad0, 32, 32, 2, 3, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 
128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 
4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, 
BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: 
Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 
16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, 
Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, 
Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_bf16.conf new file mode 100644 index 0000000000..c7a6ba489e --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_bf16.conf @@ -0,0 +1,237 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Default, 
32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 
64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 
256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 64, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, 
BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 64, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, 
BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 64, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Default, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Default, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Filter1x1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Filter1x1Stride1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Stride1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: 
Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, 
Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: 
Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, 
BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: 
Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 
1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, 
Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 
4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 
8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp16.conf new file mode 100644 index 0000000000..4e31ba2b06 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp16.conf @@ -0,0 +1,228 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Default, 32, 
32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 32, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 
2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, 
Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Default, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Default, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: 
v3> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Filter1x1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 224, 256, 64, Filter1x1Stride1Pad0, 16, 16, 7, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 224, 64, Filter1x1Stride1Pad0, 16, 16, 8, 7, 8, 8, 8, 2, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Filter1x1Stride1Pad0, 32, 32, 4, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, 
BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, 
BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, 
Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 
32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 
32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Default, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: 
Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, 
BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp32.conf new file mode 100644 index 0000000000..7dc982b6f7 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/profiler/nhwgc_fp32.conf @@ -0,0 +1,176 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Default, 32, 32, 2, 1, 4, 4, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Default, 32, 32, 2, 2, 1, 1, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default, 32, 32, 
2, 4, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Default, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 16, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 16, Default, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 16, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 16, Default, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 16, Default, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 192, 16, Default, 32, 32, 2, 3, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 16, Filter1x1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Filter1x1Pad0, 32, 32, 2, 4, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 16, Filter1x1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 
128, 128, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 16, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 16, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 16, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 192, 16, Filter1x1Pad0, 32, 32, 2, 3, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 16, Filter1x1Stride1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Filter1x1Stride1Pad0, 32, 32, 2, 4, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 128, 16, Filter1x1Stride1Pad0, 32, 32, 4, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 128, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 64, 128, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 128, 16, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 128, 32, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<128, 32, 128, 16, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 32, 16, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 16, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 192, 16, Filter1x1Stride1Pad0, 32, 32, 2, 3, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, 
Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, 
BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: 
Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Default, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Default, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Default, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Default, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Default, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Default, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, 
BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Default, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Default, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 
16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 
1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 128, 16, 64, Filter1x1Stride1Pad0, 16, 16, 4, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 32, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 64, 16, 64, Filter1x1Stride1Pad0, 16, 16, 2, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride1Pad0, 32, 32, 1, 1, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 
32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf new file mode 100644 index 0000000000..9222a0858f --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf @@ -0,0 +1,41 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 
16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf new file mode 100644 index 0000000000..9222a0858f --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf @@ -0,0 +1,41 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 
16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf new file mode 100644 index 0000000000..b9704c8100 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf @@ -0,0 +1,42 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, 
BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 32, Default, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git 
a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf new file mode 100644 index 0000000000..9222a0858f --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf @@ -0,0 +1,41 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 
16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 
32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf new file mode 100644 index 0000000000..9222a0858f --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf @@ -0,0 +1,41 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 
1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, 
Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf new file mode 100644 index 0000000000..b9704c8100 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf @@ -0,0 +1,42 @@ +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 64, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 1, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 
1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Filter3x3, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> 
+DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 
16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Default, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 4, 4, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 32, Default, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 32, Filter1x1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> +DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 32, Filter1x1Stride1Pad0, 32, 32, 1, 2, 4, 4, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v2> \ No newline at end of file diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py new file mode 100644 index 0000000000..91424987f3 --- /dev/null +++ 
b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -0,0 +1,275 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import argparse +from pathlib import Path + + +class ConvInstanceTemplateParams: + def __init__( + self, + specialization, + tile_size, + warps, + warp_tile, + double_smem_buffer, + num_wave_groups, + pipeline_version, + scheduler, + scalar_per_vector, + num_groups_to_merge, + split_image, + explicit_gemm, + id, + ): + self.specialization = specialization + self.tile_size = tile_size + self.warps = warps + self.warp_tile = warp_tile + self.double_smem_buffer = double_smem_buffer + self.num_wave_groups = num_wave_groups + self.pipeline_version = pipeline_version + self.scheduler = scheduler + self.scalar_per_vector = scalar_per_vector + self.num_groups_to_merge = num_groups_to_merge + self.split_image = split_image + self.explicit_gemm = explicit_gemm + self.id = id + + def get_optimizations(self): + explicit_gemm = "true" if self.explicit_gemm else "false" + split_image = "true" if self.split_image else "false" + num_groups_to_merge = str(self.num_groups_to_merge) + return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}}}" + + def get_specialization(self): + namespace = "ckb::TileConvSpecialization::" + if self.specialization == "Default" or self.specialization == "OddC": + return namespace + "DEFAULT" + if self.specialization == "Filter1x1Pad0": + return namespace + "FILTER_1X1_PAD0" + if self.specialization == "Filter1x1Stride1Pad0": + return namespace + "FILTER_1X1_STRIDE1_PAD0" + if self.specialization == "Filter3x3": + return namespace + "FILTER_3x3" + else: + raise RuntimeError("not supported specialization") + + def get_thread_block(self): + return f"ckt::TileThreadBlock{{.tile_size = {{.m = {self.tile_size[0]}, .n = {self.tile_size[1]}, .k = {self.tile_size[2]}}}}}" + + def 
get_block_gemm_desc(self): + double_smem_buffer = "true" if self.double_smem_buffer else "false" + pipeline_version = self.pipeline_version[-1:] + scheduler = ( + "INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE" + ) + return f"""ckt::TileBlockGemm{{ + .warps = {{.m = {self.warps[0]}, .n = {self.warps[1]}, .k = {self.warps[2]}}}, + .warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}}, + .double_smem_buffer = {double_smem_buffer}, + .num_wave_groups = {self.num_wave_groups}, + .pipeline_version = ckb::PipelineVersion::V{pipeline_version}, + .scheduler = ckb::PipelineScheduler::{scheduler}}}""" + + def get_block_transfer(self): + return f"""ckt::TileTransfer{{.a_scalar_per_vector = {self.scalar_per_vector[0]}, + .b_scalar_per_vector = {self.scalar_per_vector[1]}, .c_scalar_per_vector = {self.scalar_per_vector[2]}}}""" + + +def get_dtype(problem_name): + if problem_name.find("fp32") != -1: + return "float" + if problem_name.find("fp16") != -1: + return "ck_tile::half_t" + if problem_name.find("bf16") != -1: + return "ck_tile::bf16_t" + else: + raise RuntimeError("wrong dtype") + + +def generate_calls_inc(instances, problem_name, direction, filter_pattern): + generate_dir = Path(__file__).resolve().parent + with open(f"{generate_dir}/{problem_name}_calls.inc", "w") as f: + if problem_name.find(filter_pattern) == -1: + return + for instance in instances: + instance_name = problem_name + "_" + str(instance.id) + f.write(f"run_alg(run_{instance_name});\n") + + +def generate_defs_inc(instances, problem_name, signature, direction, filter_pattern): + generate_dir = Path(__file__).resolve().parent + with open(f"{generate_dir}/{problem_name}.inc", "w") as f: + if problem_name.find(filter_pattern) == -1: + return + for instance in instances: + instance_name = problem_name + "_" + str(instance.id) + f.write( + f"std::tuple run_{instance_name}(\n" + f" const ckt::Args<{signature}>& args,\n" + f" const 
ckt::Inputs<{signature}>& inputs,\n" + f" const ckt::Outputs<{signature}>& outputs,\n" + f" const ck_tile::stream_config& s_conf);\n" + ) + + +def generate_fwd_cpp( + instances, problem_name, config, direction, signature_name, filter_pattern +): + for instance in instances: + if problem_name.find(filter_pattern) == -1: + break + instance_name = problem_name + "_" + str(instance.id) + generate_dir = Path(__file__).resolve().parent + directory_path = Path(f"{generate_dir}/instances/{config}") + directory_path.mkdir(parents=True, exist_ok=True) + with open( + f"{generate_dir}/instances/grouped_convolution_forward_tile.cpp.in", + "r", + ) as f: + content = f.read() + + content = content.replace("gen_signature", signature_name) + content = content.replace("gen_instance_name", instance_name) + content = content.replace("gen_specialization", instance.get_specialization()) + content = content.replace("gen_thread_block", instance.get_thread_block()) + content = content.replace("gen_block_gemm_desc", instance.get_block_gemm_desc()) + content = content.replace("gen_block_transfer", instance.get_block_transfer()) + content = content.replace("gen_optimizations", instance.get_optimizations()) + + with open( + f"{generate_dir}/instances/{config}/{instance_name}.cpp", + "w", + ) as f: + f.write(content) + + +def parse_fwd_instances(instances, problem_name): + convs = [] + for instance_id, instance in enumerate(instances): + if instance.find("#") != -1 or instance.find(";") != -1: + continue + instance_args_list = instance[instance.find("<") + 1 : instance.find(">")] + args = instance_args_list.split(", ") + + block_size = int(args[0]) + m_per_block = int(args[1]) + n_per_block = int(args[2]) + k_per_block = int(args[3]) + spec = args[4] + m_per_xdl = int(args[5]) + n_per_xdl = int(args[6]) + m_xdl_per_wave = int(args[7]) + n_xdl_per_wave = int(args[8]) + a_scalar_per_vector = int(args[9]) + b_scalar_per_vector = int(args[10]) + c_scalar_per_vector = int(args[11]) + if len(args) == 
15: + num_groups_to_merge = int(args[14]) + elif len(args) != 16 and len(args) != 14: + raise RuntimeError("wrong number of parameters") + else: + num_groups_to_merge = 1 + split_image = instance.find("Large") != -1 + double_smem_buffer = instance.find("BlkGemmPipelineVersion: v4") != -1 + num_wave_groups = 2 if instance.find("BlkGemmPipelineVersion: v5") != -1 else 1 + scheduler = ( + "Intrawave" if instance.find("BlkGemmPipelineScheduler") == -1 else args[14] + ) + pipeline_version = ( + "v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15] + ) + + m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave)) + n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave)) + warp_size = 64 + k_warp = int(block_size / (warp_size * m_warp * n_warp)) + dtype = get_dtype(problem_name) + # TODO: Make it more flexible + # k_per_xdl = f"ck_tile::get_k_warp_tile<{dtype}, {m_per_xdl}>()" + k_per_xdl = 8 if dtype == "float" else 16 + + conv = ConvInstanceTemplateParams( + spec, + [m_per_block, n_per_block, k_per_block], + [m_warp, n_warp, k_warp], + [m_per_xdl, n_per_xdl, k_per_xdl], + double_smem_buffer, + num_wave_groups, + pipeline_version, + scheduler, + [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector], + num_groups_to_merge, + split_image, + False, + instance_id, + ) + convs.append(conv) + return convs + + +def generate_instances_fwd(instances, problem_name, config, filter_pattern): + direction = "forward" + signature_name = f"SIGNATURE_{config.upper()}_FWD" + instances = parse_fwd_instances(instances, problem_name) + generate_calls_inc(instances, problem_name, direction, filter_pattern) + generate_defs_inc( + instances, + problem_name, + signature_name, + direction, + filter_pattern, + ) + generate_fwd_cpp( + instances, problem_name, config, direction, signature_name, filter_pattern + ) + + +if __name__ == "__main__": + fwd_configs = [ + "nhwgc_fp32", + "nhwgc_fp16", + "nhwgc_bf16", + "ndhwgc_fp32", + "ndhwgc_fp16", + "ndhwgc_bf16", + ] + + parser = 
argparse.ArgumentParser( + description="Generate grouped conv CK Tile instances." + ) + parser.add_argument( + "--filter_pattern", + type=str, + default="convolution", + help="Filter pattern for configs.", + ) + parser.add_argument( + "--mode", + choices=["compilation", "tests", "profiler"], + type=str, + default="profiler", + help="Generator modes. compilation - empty instance list, tests - limited instance list, profiler - generate all instances", + ) + args = parser.parse_args() + + # apply empty filter + if args.mode == "compilation": + args.filter_pattern = "empty" + configs_prefix = "profiler" + elif args.mode == "tests": + configs_prefix = "tests" + elif args.mode == "profiler": + configs_prefix = "profiler" + else: + raise RuntimeError("wrong mode") + + for config in fwd_configs: + instances = [] + generate_dir = Path(__file__).resolve().parent + config_path = f"{generate_dir}/configs/{configs_prefix}/{config}.conf" + with open(config_path, "r") as file: + instances = file.readlines() + problem_name = f"grouped_convolution_forward_tile_{config}" + generate_instances_fwd(instances, problem_name, config, args.filter_pattern) diff --git a/experimental/grouped_convolution_tile_instances/instances/grouped_convolution_forward_tile.cpp.in b/experimental/grouped_convolution_tile_instances/instances/grouped_convolution_forward_tile.cpp.in new file mode 100644 index 0000000000..7e86576f7b --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/instances/grouped_convolution_forward_tile.cpp.in @@ -0,0 +1,19 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#include "../instance_includes.inc" +namespace ck_tile::builder::profiling { +constexpr auto SIGNATURE = gen_signature; +std::tuple run_gen_instance_name(const ckt::Args& args, + const ckt::Inputs& inputs, + const ckt::Outputs& outputs, + const ck_tile::stream_config& s_conf) +{ + constexpr auto ALGORITHM = cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(gen_specialization) + .with_tile_thread_block(gen_thread_block) + .with_tile_block_gemm(gen_block_gemm_desc) + .with_tile_transfer(gen_block_transfer) + .with_tile_optimizations(gen_optimizations); +#include "../instance_run.inc" +} +} // namespace ck_tile::builder::profiling diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc new file mode 100644 index 0000000000..4b4c144428 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc @@ -0,0 +1,64 @@ +#include "../../builder/test/utils/ckb_conv_tile_test_configs.hpp" +#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +namespace ck_tile::builder::profiling { + +constexpr auto SIGNATURE_NHWGC_FP32_FWD = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP32, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NHWGK}}}; + +constexpr auto SIGNATURE_NHWGC_BF16_FWD = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = 
ckb::TensorLayout::NHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NHWGK}}}; + +constexpr auto SIGNATURE_NHWGC_FP16_FWD = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NHWGK}}}; + +constexpr auto SIGNATURE_NDHWGC_FP32_FWD = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP32, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NDHWGK}}}; + +constexpr auto SIGNATURE_NDHWGC_BF16_FWD = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NDHWGK}}}; + +constexpr auto SIGNATURE_NDHWGC_FP16_FWD = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NDHWGK}}}; + +} // namespace ck_tile::builder::profiling diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc new file mode 100644 index 
0000000000..6b8024fa93 --- /dev/null +++ b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc @@ -0,0 +1,9 @@ + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +auto conv = Instance{}; +bool is_supported; +float avg_time; +std::tie(is_supported, avg_time) = ckt::run(conv, args, inputs, outputs, s_conf); +return std::make_tuple(is_supported, avg_time, conv.GetInstanceString()); diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 05bc4ded12..1dda0a4863 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -298,9 +298,12 @@ struct HostTensorDescriptor if constexpr(!(std::is_same_v || std::is_same_v)) { - std::cerr << "Only RowMajor and ColumnMajor layouts are supported for empty " - "strides, got " - << layout << ". Will calculate strides as RowMajor." << std::endl; + if(dbg) + { + std::cerr << "Only RowMajor and ColumnMajor layouts are supported for empty " + "strides, got " + << layout << ". Will calculate strides as RowMajor." << std::endl; + } } mStrides.clear(); @@ -443,9 +446,14 @@ struct HostTensorDescriptor { // TBD: implement verification for Conv layouts // For now, just print warning and return - std::cerr << "Warning: Tensor layout verification for ck::tensor_layout::convolution " - "layouts is not supported yet. Skipping..." - << std::endl; + if(dbg) + { + + std::cerr + << "Warning: Tensor layout verification for ck::tensor_layout::convolution " + "layouts is not supported yet. Skipping..." 
+ << std::endl; + } return; } else diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 936c38ddf3..9b7213837a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -90,6 +90,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1>; +template +using WarpGemmMfmaF32F32F32M16N16K8 = WarpGemmImpl, + 2, + AttrNumAccess>>; + +template +using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl, + 4, + AttrNumAccess>>; + template using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution = WarpGemmImpl struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 4af8d8a768..555264eee8 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -723,8 +723,11 @@ struct GroupedConvolutionForwardKernel if constexpr(GroupedConvTraitsType_::ExplicitGemm && ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0) { - CK_TILE_ERROR( - "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!"); + 
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!"); + } return false; } @@ -736,13 +739,19 @@ struct GroupedConvolutionForwardKernel // Check access per C if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0) { - CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + } return false; } } else { - CK_TILE_ERROR("Not supported input layout!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Not supported input layout!"); + } return false; } @@ -754,13 +763,19 @@ struct GroupedConvolutionForwardKernel { if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0) { - CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!"); + } return false; } } else { - CK_TILE_ERROR("Not supported weight layout!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Not supported weight layout!"); + } return false; } @@ -771,13 +786,20 @@ struct GroupedConvolutionForwardKernel { if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0) { - CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conv K is not a multiple of vector store size for output image!"); + } return false; } } else { - CK_TILE_ERROR("Not supported output layout!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Not supported output layout!"); + } return false; } @@ -786,7 +808,10 @@ struct GroupedConvolutionForwardKernel const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}]; if(ConvG % 
GroupedConvTraitsType_::NumGroupsToMerge != 0) { - CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!"); + } return false; } } @@ -955,7 +980,8 @@ struct GroupedConvolutionForwardKernel else { if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) + is_any_of::value) && + IsSplitKSupported) { auto c_block_window = MakeCBlockWindow( c_ptr, c_desc, block_idx_m, block_idx_n); diff --git a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp new file mode 100644 index 0000000000..e58c884729 --- /dev/null +++ b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp @@ -0,0 +1,169 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "../../experimental/builder/test/utils/conv_algorithm_type_utils.hpp" +#include "grouped_convolution_signatures.hpp" + +#include "ck_tile/builder/testing/filter_extent.hpp" +#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::profiling { + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_nhwgc_fp32.inc" +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_nhwgc_bf16.inc" +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_nhwgc_fp16.inc" +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_ndhwgc_fp32.inc" +#include 
"../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_ndhwgc_bf16.inc" +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_ndhwgc_fp16.inc" + +template +auto parse_conv_args(int arg_idx, char* const argv[]) +{ + const std::size_t G = static_cast(std::stol(argv[arg_idx++])); + const std::size_t N = static_cast(std::stol(argv[arg_idx++])); + const std::size_t K = static_cast(std::stol(argv[arg_idx++])); + const std::size_t C = static_cast(std::stol(argv[arg_idx++])); + + constexpr auto num_dim_spatial = SIGNATURE.spatial_dim; + + std::vector filter_spatial_lengths(num_dim_spatial); + std::vector input_spatial_lengths(num_dim_spatial); + std::vector conv_filter_strides(num_dim_spatial); + std::vector conv_filter_dilations(num_dim_spatial); + std::vector input_left_pads(num_dim_spatial); + std::vector input_right_pads(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + filter_spatial_lengths[i] = static_cast(std::stol(argv[arg_idx++])); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_spatial_lengths[i] = static_cast(std::stol(argv[arg_idx++])); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_strides[i] = static_cast(std::stol(argv[arg_idx++])); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_dilations[i] = static_cast(std::stol(argv[arg_idx++])); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_left_pads[i] = static_cast(std::stol(argv[arg_idx++])); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_right_pads[i] = static_cast(std::stol(argv[arg_idx++])); + } + + ckt::Args args = { + .lengths = + { + .batch_size = N, + .groups = G, + .input_channels = C, + .output_channels = K, + .image = ckt::filter_extent_from_vector(input_spatial_lengths), + .filter = ckt::filter_extent_from_vector(filter_spatial_lengths), + }, + .filter_strides = ckt::filter_extent_from_vector(conv_filter_strides), + .filter_dilation 
= ckt::filter_extent_from_vector(conv_filter_dilations), + .input_left_pad = ckt::filter_extent_from_vector(input_left_pads), + .input_right_pad = ckt::filter_extent_from_vector(input_right_pads), + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + return args; +} + +/// @brief `run_grouped_conv_forward_tile_algs()` run all grouped conv fwd instances. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see run_grouped_conv_forward_tile_algs() +template +std::tuple +run_grouped_conv_forward_tile_algs(const ckt::Args& args, + const ckt::Inputs& inputs, + const ckt::Outputs& outputs, + const ck_tile::stream_config& s_conf) +{ + float best_avg_time = std::numeric_limits::max(); + std::string best_op_name, op_name; + bool is_supported; + float avg_time; + bool valid = true; + + auto reference = ckt::alloc_outputs(args); + using ReferenceInstance = + typename ckb::ConvBuilder::Instance; + auto ref_conv = ReferenceInstance{}; + ckt::run(ref_conv, args, inputs, reference.get()); + + [[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) { + std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf); + if(is_supported) + { + const auto errors = ckt::validate(args, outputs, reference.get()).get_errors(); + for(const auto& error : errors) + { + valid = false; + std::cout << "Number of incorrect values: " << error.wrong_elements + << " Is all zero:" << error.is_all_zero() << std::endl; + } + best_avg_time = std::min(best_avg_time, avg_time); + best_op_name = best_avg_time < avg_time ? 
best_op_name : op_name; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms,"; + } + std::cout << " " << op_name << std::endl; + }; + + if constexpr(SIGNATURE == SIGNATURE_NHWGC_FP16_FWD) + { +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_nhwgc_fp16_calls.inc" + } + else if constexpr(SIGNATURE == SIGNATURE_NHWGC_BF16_FWD) + { +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_nhwgc_bf16_calls.inc" + } + else if constexpr(SIGNATURE == SIGNATURE_NHWGC_FP32_FWD) + { +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_nhwgc_fp32_calls.inc" + } + else if constexpr(SIGNATURE == SIGNATURE_NDHWGC_FP16_FWD) + { +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_ndhwgc_fp16_calls.inc" + } + else if constexpr(SIGNATURE == SIGNATURE_NDHWGC_BF16_FWD) + { +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_ndhwgc_bf16_calls.inc" + } + else if constexpr(SIGNATURE == SIGNATURE_NDHWGC_FP32_FWD) + { +#include "../../experimental/grouped_convolution_tile_instances/grouped_convolution_forward_tile_ndhwgc_fp32_calls.inc" + } + else + { + std::cout << "Signature not supported" << std::endl; + return std::make_tuple(false, best_avg_time, best_op_name); + } + return std::make_tuple(valid, best_avg_time, best_op_name); +} + +} // namespace ck_tile::builder::profiling diff --git a/profiler/include/profiler/grouped_convolution_signatures.hpp b/profiler/include/profiler/grouped_convolution_signatures.hpp new file mode 100644 index 0000000000..5103b0f235 --- /dev/null +++ b/profiler/include/profiler/grouped_convolution_signatures.hpp @@ -0,0 +1,70 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "../../experimental/builder/test/impl/conv_signature_types.hpp" +#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" + +namespace ck_tile::builder::profiling { + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +constexpr auto SIGNATURE_NHWGC_FP32_FWD = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP32, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NHWGK}}}; + +constexpr auto SIGNATURE_NHWGC_BF16_FWD = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NHWGK}}}; + +constexpr auto SIGNATURE_NHWGC_FP16_FWD = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NHWGK}}}; + +constexpr auto SIGNATURE_NDHWGC_FP32_FWD = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP32, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NDHWGK}}}; + +constexpr auto SIGNATURE_NDHWGC_BF16_FWD = + ckt::ConvSignature{.spatial_dim = 3, 
+ .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NDHWGK}}}; + +constexpr auto SIGNATURE_NDHWGC_FP16_FWD = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::NDHWGC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::NDHWGK}}}; + +} // namespace ck_tile::builder::profiling diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index e484ff9ef7..3379fd15d1 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -43,6 +43,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp) list(APPEND PROFILER_OPS profile_contraction_scale.cpp) endif() + if(CK_EXPERIMENTAL_BUILDER) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp) + endif() endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") @@ -256,6 +259,12 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) endif() +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") + if(CK_EXPERIMENTAL_BUILDER) + list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances) + endif() +endif() + if(DL_KERNELS) list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) diff --git a/profiler/src/profile_grouped_conv_fwd_tile.cpp b/profiler/src/profile_grouped_conv_fwd_tile.cpp new file mode 100644 index 0000000000..8023dcf2f6 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_tile.cpp @@ -0,0 +1,201 @@ 
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "profiler/grouped_convolution_forward_tile_algs.hpp" + +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (grouped_conv_fwd : Grouped Convolution Forward)\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: 
yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n" + << " G, N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" << std::endl; + // clang-format on +} + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace ckp = ck_tile::builder::profiling; + +template +int call_profiler(const ckt::Args& args, bool time_kernel) +{ + auto inputs = alloc_inputs(args); + auto outputs = alloc_outputs(args); + ckt::init_inputs(args, inputs.get()); + + std::cout << args.make_input_descriptor() << std::endl; + std::cout << args.make_weight_descriptor() << std::endl; + std::cout << args.make_output_descriptor() << std::endl; + float avg_time; + std::string op_name; + bool valid; + std::tie(valid, avg_time, op_name) = ckp::run_grouped_conv_forward_tile_algs( + args, inputs.get(), outputs.get(), ck_tile::stream_config{nullptr, time_kernel}); + if(time_kernel) + { + std::cout << "Best configuration parameters:" << "\nname: " << op_name + << "\navg_time: " << avg_time << std::endl; + } + return !valid; +} + +#define OP_NAME "grouped_conv_fwd_tile" +#define OP_DESC "Grouped Convolution Forward (CK Tile)" + +} // namespace + +int profile_grouped_conv_fwd_tile(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + [[maybe_unused]] const bool do_verification = std::stoi(argv[5]); + [[maybe_unused]] const int init_method = std::stoi(argv[6]); + [[maybe_unused]] const bool do_log = std::stoi(argv[7]); + const bool time_kernel = 
std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + std::cout << "IMPORTANT: Generate instances using: python " + "experimental/grouped_convolution_tile_instances/generate_instances.py --mode=profiler and rerun cmake" + << std::endl; + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type == IndexType::LONG_INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(num_dim_spatial == 2) + { + if(data_type == ConvDataType::F32_F32_F32) + { + constexpr auto SIGNATURE = ckp::SIGNATURE_NHWGC_FP32_FWD; + return call_profiler(ckp::parse_conv_args(10, argv), + time_kernel); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + constexpr auto SIGNATURE = ckp::SIGNATURE_NHWGC_FP16_FWD; + return call_profiler(ckp::parse_conv_args(10, argv), + time_kernel); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + constexpr auto SIGNATURE = ckp::SIGNATURE_NHWGC_BF16_FWD; + return call_profiler(ckp::parse_conv_args(10, argv), + time_kernel); + } + } + else if(num_dim_spatial == 3) + { + if(data_type == ConvDataType::F32_F32_F32) + { + constexpr auto SIGNATURE = ckp::SIGNATURE_NDHWGC_FP32_FWD; + return call_profiler(ckp::parse_conv_args(10, argv), + time_kernel); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + constexpr auto SIGNATURE = ckp::SIGNATURE_NDHWGC_FP16_FWD; + return call_profiler(ckp::parse_conv_args(10, argv), + time_kernel); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + constexpr auto SIGNATURE = ckp::SIGNATURE_NDHWGC_BF16_FWD; + return call_profiler(ckp::parse_conv_args(10, argv), + time_kernel); + } + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + 
+REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_conv_fwd_tile); diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 5e2db1184c..6f8b71679c 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -19,6 +19,18 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") target_link_libraries(test_grouped_convnd_fwd_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) endif() +if(GPU_TARGETS MATCHES "gfx9") + if(CK_EXPERIMENTAL_BUILDER) + # TODO: Reenable after the instance fixes + # add_executable(test_grouped_convnd_fwd_tile test_grouped_convnd_fwd_tile.cpp) + # target_compile_options(test_grouped_convnd_fwd_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) + # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE gtest_main getopt::getopt utility) + # if(TARGET device_grouped_conv_fwd_tile_instances) + # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE device_grouped_conv_fwd_tile_instances) + # endif() + endif() +endif() + add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp new file mode 100644 index 0000000000..c04a15ec98 --- /dev/null +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp @@ -0,0 +1,273 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "profiler/grouped_convolution_forward_tile_algs.hpp" + +// TODO: Remove limitation of conv fwd gpu reference which does not support right pad +#define CK_CONV_FWD_REF_SKIP_RIGHT_PAD_CASES 1 + +static ck::index_t args_mask = 0xffff; +static ck::index_t instance_index = -1; + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace ckp = ck_tile::builder::profiling; + +template +struct SignatureDetails +{ + static constexpr ck_tile::index_t num_spatial_dim = num_spatial_dim_; + static constexpr ckb::DataType data_type = data_type_; + static constexpr ckb::DataType acc_data_type = acc_data_type_; + static constexpr ckb::TensorLayout in_layout = in_layout_; + static constexpr ckb::TensorLayout wei_layout = wei_layout_; + static constexpr ckb::TensorLayout out_layout = out_layout_; +}; + +template +class TestGroupedConvndFwdTile : public ::testing::Test +{ + protected: + static constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = SignatureDetailsType::num_spatial_dim, + .direction = ckb::ConvDirection::FORWARD, + .data_type = SignatureDetailsType::data_type, + .accumulation_data_type = SignatureDetailsType::acc_data_type, + .input = {.config = {.layout = SignatureDetailsType::in_layout}}, + .weight = {.config = {.layout = SignatureDetailsType::wei_layout}}, + .output = {.config = {.layout = SignatureDetailsType::out_layout}}}; + + std::vector> conv_args; + + template + void Run() + { + EXPECT_FALSE(conv_args.empty()); + bool pass = true; + for(size_t i = 0; i < conv_args.size(); i++) + { + if((args_mask & (1 << i)) == 0) + { + continue; + } + auto& args = conv_args[i]; + + auto inputs = alloc_inputs(args); + auto outputs = alloc_outputs(args); + ckt::init_inputs(args, inputs.get()); + + std::cout << args.make_input_descriptor() << 
std::endl; + std::cout << args.make_weight_descriptor() << std::endl; + std::cout << args.make_output_descriptor() << std::endl; + float avg_time; + std::string op_name; + bool case_passed; + std::tie(case_passed, avg_time, op_name) = ckp::run_grouped_conv_forward_tile_algs( + args, + inputs.get(), + outputs.get(), + ck_tile::stream_config{nullptr, false /*time_kernel*/}); + + pass = pass && case_passed; + } + EXPECT_TRUE(pass); + } + + void conv_args_append(std::size_t, + std::size_t G, + std::size_t N, + std::size_t K, + std::size_t C, + const std::vector& filter_spatial_lengths, + const std::vector& input_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { +#if CK_CONV_FWD_REF_SKIP_RIGHT_PAD_CASES + bool without_right_pad = true; + for(const std::size_t& right_pad : input_right_pads) + { + without_right_pad &= right_pad == 0; + } + if(!without_right_pad) + { + return; + } +#endif + ckt::Args args = { + .lengths = + { + .batch_size = N, + .groups = G, + .input_channels = C, + .output_channels = K, + .image = ckt::filter_extent_from_vector( + input_spatial_lengths), + .filter = ckt::filter_extent_from_vector( + filter_spatial_lengths), + }, + .filter_strides = ckt::filter_extent_from_vector( + conv_filter_strides), + .filter_dilation = + ckt::filter_extent_from_vector( + conv_filter_dilations), + .input_left_pad = ckt::filter_extent_from_vector( + input_left_pads), + .input_right_pad = + ckt::filter_extent_from_vector( + input_right_pads), + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + conv_args.push_back(args); + } +}; + +using KernelTypes2d = ::testing::Types, + SignatureDetails<2, + ckb::DataType::FP16, + ckb::DataType::FP32, + ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, + ckb::TensorLayout::NHWGK>, + SignatureDetails<2, + ckb::DataType::BF16, + ckb::DataType::FP32, + 
ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, + ckb::TensorLayout::NHWGK>>; + +using KernelTypes3d = ::testing::Types, + SignatureDetails<3, + ckb::DataType::FP16, + ckb::DataType::FP32, + ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, + ckb::TensorLayout::NDHWGK>, + SignatureDetails<3, + ckb::DataType::BF16, + ckb::DataType::FP32, + ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, + ckb::TensorLayout::NDHWGK>>; + +template +class TestGroupedConvndFwdTile2d : public TestGroupedConvndFwdTile +{ +}; + +template +class TestGroupedConvndFwdTile3d : public TestGroupedConvndFwdTile +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdTile2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwdTile3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdTile2d, Test2D) +{ + this->conv_args.clear(); + this->conv_args_append(2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + + this->conv_args_append(2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}); + + this->conv_args_append(2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + + this->conv_args_append(2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + + this->conv_args_append(2, 96, 1, 1, 1, 
{1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwdTile3d, Test3D) +{ + this->conv_args.clear(); + + this->conv_args_append( + 3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + + this->conv_args_append( + 3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + + this->conv_args_append( + 3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + + this->conv_args_append( + 3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + + this->conv_args_append( + 3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + 
this->template Run<3>(); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + args_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: args_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} From b09121f86066381f3662fdbdee6a810849a8a1a7 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:50:46 +0100 Subject: [PATCH 28/99] WMMA support for batched_gemm_reduce (#3332) Summary: - added new device impl of Batched GEMM Reduce for WMMA - added instance library - added WMMA impl to the Batched GEMM Reduce tests --- ...e_batched_gemm_reduce_wmma_cshuffle_v3.hpp | 799 ++++++++++++++++++ .../gpu/batched_gemm_reduce/CMakeLists.txt | 6 +- ...6_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp | 88 ++ ...6_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp | 88 ++ ...6_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp | 87 ++ ...6_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp | 86 ++ ...6_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp | 2 +- ...6_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp | 2 +- ...6_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp | 2 +- .../profile_batched_gemm_reduce_impl.hpp | 67 +- test/batched_gemm_reduce/CMakeLists.txt | 10 +- .../batched_gemm_reduce_fp16.cpp | 119 +++ .../batched_gemm_reduce_fp16_xdl.cpp | 67 -- 13 files changed, 1345 insertions(+), 78 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp create mode 
100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp create mode 100644 test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp delete mode 100644 test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..227a8aedd9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp @@ -0,0 +1,799 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_reduce_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops, + const typename 
ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + auto reduces_batch = p_reduces_grid; + compute_ptr_offset_of_batch.OffsetReducePtrs(g_idx, reduces_batch); + + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + + auto epilogue_args = EpilogueType(reduces_batch, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + 
tensor_operation::element_wise::PassThrough{}); + + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = p_reduces_grid; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = compute_ptr_offset_of_batch; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmReduce_Wmma_CShuffleV3 + : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceBatchedGemmReduce_Wmma_CShuffleV3; + + static_assert(PermuteA == false, + "Permute A functionality not supported by DeviceBatchedGemm operations.\n"); + static_assert(PermuteB == false, + "Permute B functionality not supported by DeviceBatchedGemm operations.\n"); + + using CDEShuffleBlockTransferScalarPerVectors = + Sequence; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, 
+ CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + false, + false, + true>; + + using ReduceTrait = ReduceTrait_; + + static constexpr index_t NumReduce = ReduceOperations::Size(); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, + long_index_t BatchStrideB, + long_index_t BatchStrideC, + std::array BatchStrideReduce) + : BatchStrideA_{BatchStrideA}, + BatchStrideB_{BatchStrideB}, + BatchStrideC_{BatchStrideC}, + BatchStrideReduce_{BatchStrideReduce} + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * BatchStrideA_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * BatchStrideB_; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * BatchStrideC_; + } + + template + __host__ __device__ void OffsetReducePtrs(index_t g_idx, ReducePtrs& ptrs) const + { + static_for<0, NumReduce, 1>{}( + [&](auto I) { ptrs(I) = ptrs(I) + g_idx * BatchStrideReduce_[I.value]; }); + } + + private: + long_index_t BatchStrideA_; + long_index_t BatchStrideB_; + long_index_t BatchStrideC_; + std::array BatchStrideReduce_{}; + }; + + private: + static long_index_t ComputeABatchStride(index_t MRaw, index_t KRaw, index_t StrideA) + { + if constexpr(is_same_v) + { + return static_cast(MRaw) * StrideA; + } + else + { + return static_cast(KRaw) * StrideA; + } + } + + static long_index_t ComputeBBatchStride(index_t KRaw, index_t NRaw, index_t StrideB) + { + if constexpr(is_same_v) + { + return static_cast(KRaw) * StrideB; + } + else + { + return static_cast(NRaw) * StrideB; + } + } + + static long_index_t ComputeCBatchStride(index_t MRaw, index_t NRaw, index_t StrideC) + { + if 
constexpr(is_same_v) + { + return static_cast(MRaw) * StrideC; + } + else + { + return static_cast(NRaw) * StrideC; + } + } + + public: + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops, + std::array batch_stride_reduce) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_e_grid_{p_e_grid}, + p_reduces_grid_{p_reduces_grid}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideC_{StrideC}, + Batch_{Batch}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops}, + batch_stride_reduce_{batch_stride_reduce}, + compute_ptr_offset_of_batch_( + ComputePtrOffsetOfStridedBatch{ComputeABatchStride(MRaw, KRaw, StrideA), + ComputeBBatchStride(KRaw, NRaw, StrideB), + ComputeCBatchStride(MRaw, NRaw, StrideC), + batch_stride_reduce}) + { + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + index_t StrideC_; + index_t Batch_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + std::array batch_stride_reduce_{}; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public 
BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideC + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg, true)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + gdy *= arg.Batch_; + + float ave_time = 0; + + const index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split); + + const auto Run = [&](const auto& kernel) { + // Note: cache flushing not supported + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.p_reduces_grid_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.compute_ptr_offset_of_batch_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + auto CreateAndRunKernel = [&](auto has_main_k_block_loop_, auto tail_number_) { + constexpr bool has_loop = decltype(has_main_k_block_loop_)::value; + constexpr TailNumber tn = tail_number_; + + const auto kernel = + kernel_batched_gemm_reduce_wmma_cshuffle_v3; + + Run(kernel); + }; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(has_main_k_block_loop && tail_num == TailNumber::Full) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!has_main_k_block_loop && tail_num == TailNumber::Full) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid has_main_k_block_loop and tail_num combination for V1!\n"); + return 0.0f; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(has_main_k_block_loop && tail_num == TailNumber::Full) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!has_main_k_block_loop && tail_num == TailNumber::Even) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!has_main_k_block_loop && tail_num == TailNumber::Odd) + { + CreateAndRunKernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid has_main_k_block_loop and tail_num combination for V3!\n"); + return 0.0f; + } + } + else + { + printf("Invalid pipeline version!\n"); + return 0.0f; + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() { return true; } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Device 
implementation supports only gfx11 and gfx12! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideC + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + return GridwiseGemm::CheckValidity(gemm_arg, true); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_e, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t Batch) 
+ { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I.value]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I.value])); + }, + Number{}); + + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I.value])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + std::array batch_stride_reduce{}; + static_for<0, NumReduce, 1>{}( + [&](auto I) { batch_stride_reduce[I.value] = static_cast(M); }); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_e), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + Batch, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + batch_stride_reduce}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_e, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t Batch = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + 
(void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I.value]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I.value])); + }, + Number{}); + + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I.value])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + std::array batch_stride_reduce{}; + static_for<0, NumReduce, 1>{}( + [&](auto I) { batch_stride_reduce[I.value] = static_cast(M); }); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_e), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + Batch, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + batch_stride_reduce); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + str << "DeviceBatchedGemmReduce_Wmma_CShuffleV3" << "<" << BlockSize << ", " << MPerBlock + << ", " << NPerBlock << ", " << KPerBlock << ", " << AK1 << ", " << BK1 << ", " + << MPerWmma << ", " << NPerWmma << ", " << MRepeat << ", " << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " << CShuffleNRepeatPerShuffle << 
">"; + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index a098a0a7e5..89626f1afa 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,10 +1,14 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..e604c358cf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro 
Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, k, m] * b[g, k, n] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, 
PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..6bd538750b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, k, m] * b[g, n, k] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, 
ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 
2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, 
PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..8d75ef7b65 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, m, k] * b[g, k, n] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, 
ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, 
S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, 
ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 
32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..7386ab3bf7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, 
ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + // v1 Interwave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, 
PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp index c8564c120c..689657a505 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -39,7 +39,7 @@ using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; +#ifdef CK_ENABLE_FP16 +#ifdef CK_USE_XDL void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( std::vector&); @@ -44,6 +47,22 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( std::vector&); +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector&); + +void 
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector&); +#endif // CK_USE_WMMA +#endif // CK_ENABLE_FP16 } // namespace instance } // namespace device @@ -210,6 +229,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, // add device GEMM instances std::vector gemm_ptrs; +#ifdef CK_ENABLE_FP16 if constexpr(is_same::value && is_same::value && is_same::value) { @@ -217,35 +237,64 @@ bool profile_batched_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: 
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + gemm_ptrs); +#endif } } +#endif // CK_ENABLE_FP16 if(gemm_ptrs.size() <= 0) { @@ -318,9 +367,21 @@ bool profile_batched_gemm_reduce_impl(int do_verification, reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); - bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); - bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result); - bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result); + bool c_error = ck::utils::check_err(c_g_m_n_device_result, + c_g_m_n_host_result, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + bool d0_error = ck::utils::check_err(d0_g_m_device_result, + d0_g_m_host_result, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + bool d1_error = ck::utils::check_err(d1_g_m_device_result, + d1_g_m_host_result, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); pass = pass && (c_error == true); pass = pass && (d0_error == true); diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 4348c4b536..b2765148dd 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,7 +1,9 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) - endif() +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + add_gtest_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) + endif() +endif() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp new file mode 100644 index 0000000000..71cd12e534 --- /dev/null +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -0,0 +1,119 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_batched_gemm_reduce_impl.hpp" + +static ck::index_t param_mask = 0xffff; +static ck::index_t instance_index = -1; +struct GemmParams +{ + ck::index_t M; + ck::index_t N; + ck::index_t K; + ck::index_t BatchCount; +}; + +class TestBatchedGemmReduce : public ::testing::Test +{ + protected: + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + std::vector params; + + bool Run() + { + bool pass = true; + for(size_t i = 0; i < params.size(); i++) + { + if((param_mask & (1 << i)) == 0) + { + continue; + } + const auto& param = params[i]; + const auto M = param.M; + const auto N = param.N; + const auto K = param.K; + const auto BatchCount = param.BatchCount; + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, K, N, BatchCount); + + pass = pass && 
ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, K, N, BatchCount); + } + return pass; + } +}; + +#ifdef CK_ENABLE_FP16 +TEST_F(TestBatchedGemmReduce, fp16) +{ + this->params.push_back({64, 64, 64, 2}); + this->params.push_back({64, 64, 64, 1}); + this->params.push_back({40, 40, 40, 2}); + this->params.push_back({256, 256, 128, 3}); + + // Tests with larger MNK + this->params.push_back({512, 256, 128, 1}); + this->params.push_back({256, 240, 192, 2}); + this->params.push_back({256, 256, 128, 3}); + this->params.push_back({240, 128, 128, 5}); + EXPECT_TRUE(this->Run()); +} +#endif + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + param_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp deleted file mode 100644 index 8e4c60d545..0000000000 --- a/test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include - -#include "profiler/profile_batched_gemm_reduce_impl.hpp" - -int main() -{ - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - int M = 512; - int N = 256; - int K = 128; - - int BatchCount = 3; - - bool pass = true; - - pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( - true, 1, false, false, M, N, K, K, N, N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( - true, 1, false, false, M, N, K, K, K, N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( - true, 1, false, false, M, N, K, M, N, N, BatchCount); - - pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( - true, 1, false, false, M, N, K, M, K, N, BatchCount); - - if(pass) - { - std::cout << "test BatchedGEMM+Reduce fp16: Pass" << std::endl; - return 0; - } - else - { - std::cout << "test BatchedGEMM+Reduce fp16: Fail" << std::endl; - return -1; - } -} From 6300ad3c62298dc6fdddfcf19ecd074f7f08fa96 Mon Sep 17 00:00:00 2001 From: music-dino <111048524+music-dino@users.noreply.github.com> Date: Tue, 20 Jan 2026 16:25:30 +0100 Subject: [PATCH 29/99] Batched gemm softmax gemm descriptor fix (#3564) * Add rocm to prefix path for codegen * Fix issue with c0_matrix_mask construction --- .../impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 35b2f54f58..e3a990bcb1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -1059,7 +1059,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle c_grid_desc_m_n)}, 
has_main_k_block_loop{GridwiseGemm64::CalculateHasMainKBlockLoop( a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, - c0_matrix_mask{c.GetLength(I1)}, + c0_matrix_mask{b.GetLength(I0)}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, b1_element_op{b1_element_op_}, From 4d58c70e6cf76ce6cb40aa6035ebccbb28493f71 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:01:33 -0700 Subject: [PATCH 30/99] [CK TILE GEMM] Add bf8 support to tile engine streamk generator (#3543) --- tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index bea46de067..d7aaa6121a 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -307,6 +307,7 @@ class GemmKernelBuilder: "fp16": "ck_tile::fp16_t", "fp8": "ck_tile::fp8_t", "bf16": "ck_tile::bf16_t", + "bf8": "ck_tile::bf8_t", "fp32": "float", "fp64": "double", } @@ -776,7 +777,7 @@ def main(): parser.add_argument( "--datatype", required=True, - choices=["fp16", "fp8", "bf16", "fp32", "fp64"], + choices=["fp16", "fp8", "bf16", "bf8", "fp32", "fp64"], help="Data type", ) parser.add_argument( From 7d8bca7ddcff71281e4c75630e97a5e63cee057e Mon Sep 17 00:00:00 2001 From: Estevan Vedovelli Date: Tue, 20 Jan 2026 12:39:57 -0500 Subject: [PATCH 31/99] Add support to fp16 + compute fp16 and bf16 + compute bf16 contractions (#3598) * Add support to fp16 + compute fp16 and bf16 + compute bf16 contractions Enables hipTensor to access the WMMA HW functionalities for these combinations of datatype on gfx11 and gfx12. 
* Fix change to contraction scale tests * Fix clang-format --- example/26_contraction/CMakeLists.txt | 12 + .../contraction_bilinear_xdl_bf16.cpp | 86 ++++++ .../contraction_bilinear_xdl_fp16.cpp | 86 ++++++ .../contraction_scale_xdl_bf16.cpp | 85 ++++++ .../contraction_scale_xdl_fp16.cpp | 85 ++++++ .../run_contraction_bilinear_example.inc | 21 +- .../run_contraction_scale_example.inc | 21 +- .../gpu/contraction_bilinear.hpp | 260 +++++++++++++++++- .../gpu/contraction_scale.hpp | 260 +++++++++++++++++- ...ffle_bf16_bf16_bf16_bf16_kknn_instance.cpp | 58 ++++ ...ffle_bf16_bf16_bf16_bf16_knnn_instance.cpp | 58 ++++ ...ffle_bf16_bf16_bf16_bf16_mknn_instance.cpp | 58 ++++ ...ffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp | 58 ++++ ..._shuffle_f16_f16_f16_f16_kknn_instance.cpp | 58 ++++ ..._shuffle_f16_f16_f16_f16_knnn_instance.cpp | 58 ++++ ..._shuffle_f16_f16_f16_f16_mknn_instance.cpp | 58 ++++ ..._shuffle_f16_f16_f16_f16_mnnn_instance.cpp | 58 ++++ ...ffle_bf16_bf16_bf16_bf16_kknn_instance.cpp | 58 ++++ ...ffle_bf16_bf16_bf16_bf16_knnn_instance.cpp | 58 ++++ ...ffle_bf16_bf16_bf16_bf16_mknn_instance.cpp | 58 ++++ ...ffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp | 58 ++++ ..._shuffle_f16_f16_f16_f16_kknn_instance.cpp | 60 ++++ ..._shuffle_f16_f16_f16_f16_knnn_instance.cpp | 58 ++++ ..._shuffle_f16_f16_f16_f16_mknn_instance.cpp | 58 ++++ ..._shuffle_f16_f16_f16_f16_mnnn_instance.cpp | 58 ++++ .../gpu/contraction_bilinear/CMakeLists.txt | 10 + ..._c_shuffle_bf16_bf16_bf16_kkn_instance.cpp | 57 ++++ ..._c_shuffle_bf16_bf16_bf16_knn_instance.cpp | 57 ++++ ..._c_shuffle_bf16_bf16_bf16_mkn_instance.cpp | 57 ++++ ..._c_shuffle_bf16_bf16_bf16_mnn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_knn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp | 57 ++++ ..._c_shuffle_bf16_bf16_bf16_kkn_instance.cpp | 57 ++++ 
..._c_shuffle_bf16_bf16_bf16_knn_instance.cpp | 57 ++++ ..._c_shuffle_bf16_bf16_bf16_mkn_instance.cpp | 57 ++++ ..._c_shuffle_bf16_bf16_bf16_mnn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_knn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp | 57 ++++ ...xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp | 57 ++++ .../gpu/contraction_scale/CMakeLists.txt | 10 + test/contraction/test_contraction_xdl.cpp | 8 +- 44 files changed, 2762 insertions(+), 24 deletions(-) create mode 100644 example/26_contraction/contraction_bilinear_xdl_bf16.cpp create mode 100644 example/26_contraction/contraction_bilinear_xdl_fp16.cpp create mode 100644 example/26_contraction/contraction_scale_xdl_bf16.cpp create mode 100644 example/26_contraction/contraction_scale_xdl_fp16.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 4a41bc5e65..c8a516bae6 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -38,16 +38,28 @@ add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contracti add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) # FP16 +add_example_executable(example_contraction_bilinear_xdl_fp16 contraction_bilinear_xdl_fp16.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16) + add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) +add_example_executable(example_contraction_scale_xdl_fp16 contraction_scale_xdl_fp16.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16) + add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) # BF16 +add_example_executable(example_contraction_bilinear_xdl_bf16 contraction_bilinear_xdl_bf16.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16) + add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 
contraction_bilinear_xdl_bf16_compute_fp32.cpp) add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) +add_example_executable(example_contraction_scale_xdl_bf16 contraction_scale_xdl_bf16.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16) + add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) diff --git a/example/26_contraction/contraction_bilinear_xdl_bf16.cpp b/example/26_contraction/contraction_bilinear_xdl_bf16.cpp new file mode 100644 index 0000000000..8899b54fbf --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_bf16.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DDataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; +using ComputeDataType = BF16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { 
return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp16.cpp new file mode 100644 index 0000000000..16e33e0886 --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp16.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_bf16.cpp b/example/26_contraction/contraction_scale_xdl_bf16.cpp new file mode 100644 index 0000000000..586b022397 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_bf16.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; +using ComputeDataType = BF16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp16.cpp b/example/26_contraction/contraction_scale_xdl_fp16.cpp new file mode 100644 index 0000000000..1f29e16223 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp16.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc index 69eb42defd..08ed098b66 100644 --- a/example/26_contraction/run_contraction_bilinear_example.inc +++ b/example/26_contraction/run_contraction_bilinear_example.inc @@ -235,13 +235,20 @@ int run_contraction_bilinear_example(int argc, char* argv[]) if(ck::is_gfx11_supported()) { - return ck::utils::check_err(e_ms_ns_device_result, - e_ms_ns_host_result, - "Error: Incorrect results!", - 1e-4, - 1e-4) - ? 0 - : 1; + if constexpr(std::is_same_v) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 
0 : 1; + } } else { diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc index a7451fab71..a5bcd8d447 100644 --- a/example/26_contraction/run_contraction_scale_example.inc +++ b/example/26_contraction/run_contraction_scale_example.inc @@ -218,13 +218,20 @@ int run_contraction_scale_example(int argc, char* argv[]) if(ck::is_gfx11_supported()) { - return ck::utils::check_err(e_ms_ns_device_result, - e_ms_ns_host_result, - "Error: Incorrect results!", - 1e-4, - 1e-4) - ? 0 - : 1; + if constexpr(std::is_same_v) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } } else { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index 02cf3df942..0d799bf15d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -282,6 +282,58 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp #endif // CK_ENABLE_FP64 #ifdef CK_ENABLE_FP16 +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance( + std::vector>>& instances); + void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( std::vector>>& 
instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance( + std::vector>>& instances); + void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance( + std::vector>>& instances); + void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance( + std::vector>>& instances); + void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( std::vector) + if constexpr(is_same_v) + { + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) { 
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( op_ptrs); @@ -952,7 +1171,18 @@ struct DeviceOperationInstanceFactory) + if constexpr(is_same_v) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) { add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( op_ptrs); @@ -972,7 +1202,18 @@ struct DeviceOperationInstanceFactory) + if constexpr(is_same_v) + { + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) { add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( op_ptrs); @@ -986,7 +1227,18 @@ struct DeviceOperationInstanceFactory) + if constexpr(is_same_v) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) { 
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp index 50b9f33f9a..7945d409b3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -282,6 +282,58 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32 #endif // CK_ENABLE_FP64 #ifdef CK_ENABLE_FP16 +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance( + std::vector>>& instances); + void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance( + std::vector>>& instances); + void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance( + std::vector>>& instances); + +void 
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance( + std::vector>>& instances); + void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance( + std::vector>>& instances); + void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( std::vector>>& instances); -#endif // CK_ENABLE_FP16 +#endif // CK_ENABLE_BF16 // Contraction + Scale template ) + if constexpr(is_same_v) + { + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) { add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( op_ptrs); @@ -951,7 +1170,18 @@ struct DeviceOperationInstanceFactory) + if constexpr(is_same_v) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) { add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( op_ptrs); @@ -971,7 +1201,18 @@ struct DeviceOperationInstanceFactory) + if 
constexpr(is_same_v) + { + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) { add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( op_ptrs); @@ -985,6 +1226,17 @@ struct DeviceOperationInstanceFactory) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance( + op_ptrs); + } if constexpr(is_same_v) { add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp new file mode 100644 index 0000000000..ce57ee2d07 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp new file mode 100644 index 0000000000..e1e5dbb434 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp new file mode 100644 index 0000000000..db98406390 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp new file mode 100644 index 0000000000..5c7032e854 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp new file mode 100644 index 0000000000..a0c8376980 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance{}); +} + +} // namespace instance +} // namespace device 
+} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp new file mode 100644 index 0000000000..0798f7a9b6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + 
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp new file mode 100644 index 0000000000..7da8371482 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp new file mode 100644 index 0000000000..49267e0867 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp new file mode 100644 index 0000000000..77fae91ffe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp @@ -0,0 
+1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp new file mode 100644 index 0000000000..9b8cacc5e1 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp new file mode 100644 index 0000000000..50a7645256 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp new file mode 100644 index 0000000000..78aa99fa6e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance{}); +} + +} // namespace instance 
+} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp new file mode 100644 index 0000000000..e738e54f06 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance( + std::vector>>& instances) +{ +
 add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp new file mode 100644 index 0000000000..4bc5b1684a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary!
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp new file mode 100644 index 0000000000..e320fbe11a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp new file mode 100644 index 0000000000..bbb90a6af4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 9850882c55..b9cde18e24 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -37,12 +37,22 @@ foreach(idx IN LISTS DIMS) ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp) # 
FP16 + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp) + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp) # BF16 + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp) + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp new file mode 100644 index 0000000000..c85f8cc998 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp new file mode 100644 index 0000000000..d4a25d40cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, 
Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp new file mode 100644 index 0000000000..7be8a0a694 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 
Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp new file mode 100644 index 0000000000..b2a4c020e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp @@ -0,0 +1,57 
@@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp new file mode 100644 index 0000000000..52042dd045 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp @@ 
-0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp new file mode 100644 index 0000000000..2b6aed8ed4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp 
@@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp new file mode 100644 index 0000000000..07cbbf87c6 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp new file mode 100644 index 0000000000..2cc4bfb718 --- /dev/null 
+++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp new file mode 100644 index 0000000000..9244f6a132 --- 
/dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp new file mode 100644 index 
0000000000..99e80e0e28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp new 
file mode 100644 index 0000000000..77ca8c0d16 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp new file mode 100644 index 0000000000..564fe537bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp new file mode 100644 index 0000000000..dfc187562a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff 
--git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp new file mode 100644 index 0000000000..50d951a99c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck 
diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp new file mode 100644 index 0000000000..460c5c4b49 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace 
ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp new file mode 100644 index 0000000000..bee17f3386 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // 
namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index a45bea6460..542c7b8200 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -37,12 +37,22 @@ foreach(idx IN LISTS DIMS) ${PREFIX}_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp) # FP16 + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp) + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp) # BF16 + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp) + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp diff --git a/test/contraction/test_contraction_xdl.cpp b/test/contraction/test_contraction_xdl.cpp index 373aaa2597..70b11c3bdb 100644 --- a/test/contraction/test_contraction_xdl.cpp +++ b/test/contraction/test_contraction_xdl.cpp @@ -121,10 +121,14 @@ class TestContractionBilinear : public TestContraction using BilinearKernelTypes = ::testing::Types, F32, Bilinear), - 
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F64, Bilinear)>; + ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F64, Bilinear), + ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple, F16, Bilinear), + ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple, BF16, Bilinear)>; using ScaleKernelTypes = ::testing::Types, F32, Scale), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>; + ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale), + ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F16, Scale), + ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, BF16, Scale)>; TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); From 8f75869408210cb85e9eb7ff639c4c9dad1331cb Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 21 Jan 2026 01:40:54 +0800 Subject: [PATCH 32/99] Revert "[CK_TILE][FMHA] Add new tile size for async (#3586)" (#3613) This reverts commit f3aafb95552cc2570f952667848310fbe3e982e7. --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 +------- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 2 -- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 81c7b067d3..dd65c0298b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -315,7 +315,7 @@ class FmhaFwdApiTrait: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0 or self.bm0 == 64: + if self.bm0 == max_bm0: return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}" @@ -847,11 +847,6 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) and kernel_ctx.tile.F_bm0 != 128 ) - or ( - (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) - and kernel_ctx.pipeline.tag != "qr_async" - and kernel_ctx.tile.F_bk0 == 64 - ) ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # 
non qr_async only support kn0=128 tile size when hdim is 128 @@ -947,7 +942,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e30d4215d6..7224ed3a70 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -321,8 +321,6 @@ struct BlockFmhaPipelineQRKSVSAsync { if(num_total_loop <= 0) { - buffer_load_fence(0); // rocm-7.1.1, if whole tile is masked out, need to fence(0) - // otherwise will have compute error(maybe compiler bug?) if constexpr(kStoreLSE) { auto lse = From 91b4102a59c6013d3faeb54f250cf577b2f129ce Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:37:09 -0800 Subject: [PATCH 33/99] Add persistent async input scheduler for GEMM kernels (#3520) Add signal-based synchronization for persistent GEMM kernels where input data becomes available incrementally. 
Uses modulo wraparound (like PyTorch's AsyncMM) for chunk index calculation: chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks Key components: - PersistentAsyncInputScheduler struct with tiles_per_chunk_m, chunk_signals, tile_idx_pivot_m, and num_chunks fields - wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency - IsSupportedArgument validation for scheduler parameters - Example demonstrating async input scheduling with simulated producer - GTest unit tests covering all layout combinations --- CHANGELOG.md | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 3 +- example/ck_tile/03_gemm/universal_gemm.cpp | 229 +++++++++++-- .../03_gemm/universal_gemm_invoker.hpp | 170 ++++++++++ include/ck_tile/core.hpp | 1 + .../ck_tile/core/arch/workgroup_barrier.hpp | 30 ++ .../persistent_async_input_scheduler.hpp | 49 +++ .../ops/gemm/kernel/universal_gemm_kernel.hpp | 98 ++++-- test/ck_tile/CMakeLists.txt | 1 + .../CMakeLists.txt | 19 ++ .../test_gemm_persistent_async_input.cpp | 304 ++++++++++++++++++ 11 files changed, 844 insertions(+), 61 deletions(-) create mode 100644 include/ck_tile/core/utility/persistent_async_input_scheduler.hpp create mode 100644 test/ck_tile/gemm_persistent_async_input/CMakeLists.txt create mode 100644 test/ck_tile/gemm_persistent_async_input/test_gemm_persistent_async_input.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 066dc9aa3b..c3a257e464 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for gfx1153 target. * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. +* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. 
### Changed diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 8eff0e7469..c1df27ecc8 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -456,7 +456,8 @@ inline auto create_args() .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "gemm.json", "json file name to dump results") .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") - .insert("rotating_count", "1000", "rotating count, defaults to 1000"); + .insert("rotating_count", "1000", "rotating count, defaults to 1000") + .insert("test_async", "0", "0: normal gemm, 1: test async input scheduler"); return arg_parser; } diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index c1c8a2fc89..ace9152747 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -12,6 +12,169 @@ #include "run_gemm_example_common.hpp" #include "universal_gemm_invoker.hpp" +// Universal GEMM-specific wrapper that handles test_async flag +template +int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser, + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + const CLayout c_layout = CLayout{}) +{ + using Invoker = UniversalInvoker; + using AccDataType = typename GemmTypeConfig::AccDataType; + + // Check for async input scheduler test mode + bool test_async = arg_parser.get_int("test_async"); + if(test_async) + { + // Extract parameters for async test (same as shared implementation) + const ck_tile::index_t M = arg_parser.get_int("m"); + const ck_tile::index_t N = arg_parser.get_int("n"); + const ck_tile::index_t K = arg_parser.get_int("k"); + const ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + constexpr bool is_a_row_major = std::is_same_v; + constexpr bool 
is_b_row_major = std::is_same_v; + constexpr bool is_c_row_major = std::is_same_v; + + const ck_tile::index_t stride_A = is_a_row_major ? K : M; + const ck_tile::index_t stride_B = is_b_row_major ? N : K; + const ck_tile::index_t stride_C = is_c_row_major ? N : M; + + // Allocate and initialize tensors + ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( + M, K, stride_A, ck_tile::bool_constant{})); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + K, N, stride_B, ck_tile::bool_constant{})); + ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( + M, N, stride_C, ck_tile::bool_constant{})); + + ck_tile::FillUniformDistributionIntegerValue{-5, 5}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; + + Invoker::template test_async_input_scheduler, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough>( + args, ck_tile::stream_config{nullptr, false, 1}); + + // Copy result from device for verification + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + // Compute CPU reference + ck_tile::HostTensor c_m_n_ref(ck_tile::host_tensor_descriptor( + M, N, stride_C, ck_tile::bool_constant{})); + c_m_n_ref.SetZero(); + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_ref); + + // Verify results + const float max_accumulated_value = + *std::max_element(c_m_n_ref.mData.begin(), 
c_m_n_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU"); + + std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl; + return pass; + } + + // Normal path - delegate to shared implementation + return run_gemm_example_with_layouts( + arg_parser, a_layout, b_layout, c_layout); +} + +// Universal GEMM-specific prec_type dispatcher that uses the wrapper +template +int run_gemm_example_prec_type_universal(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && std::is_same_v) + { + throw std::runtime_error("Preshuffle is not supported for this int4 datatype!"); + } + + if(preshuffle && a_layout != "R" && b_layout != "C") + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + using LayoutVariant = std::variant; + + auto string_to_layout = [](const std::string& layout) -> LayoutVariant { + if(layout == "R") + return Row{}; + if(layout == "C") + return Col{}; + throw std::runtime_error("Unsupported layout: " + layout); + }; + + auto a_layout_variant = string_to_layout(a_layout); + auto b_layout_variant = string_to_layout(b_layout); + + return std::visit( + [&](auto a_layout_type, auto b_layout_type) -> int { + if constexpr(std::is_same_v && + std::is_same_v) + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + else + { + return run_gemm_example_with_layouts_universal( + arg_parser, a_layout_type, b_layout_type, Row{}); + } + }, + a_layout_variant, + b_layout_variant); +} + template