From b15df372553e0f80a660124f1b558d9cb276bd08 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 10 Dec 2025 23:08:41 +0800 Subject: [PATCH 01/10] fix: python 3.8 compatibility in fmha codegen (#3388) --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c00bdcea3b..edc0e049c5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -770,7 +770,7 @@ def create_kernel( class CompatibilityRuleFactory: @staticmethod - def get_rules() -> list[CompatibilityRule]: + def get_rules() -> List[CompatibilityRule]: # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: if problem_ctx.mode == "group": @@ -812,7 +812,7 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) @classmethod - def get_rules(cls) -> list[CompatibilityRule]: + def get_rules(cls) -> List[CompatibilityRule]: rules = CompatibilityRuleFactory.get_rules() def check_hdim_tile( @@ -846,7 +846,7 @@ class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): ) @classmethod - def get_rules(cls) -> list[CompatibilityRule]: + def get_rules(cls) -> List[CompatibilityRule]: rules = CompatibilityRuleFactoryGfx9.get_rules() def check_tile_pipeline( From 15ed65db35e6702593cd8ed1d603222fb11684e4 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 10 Dec 2025 12:25:23 -0800 Subject: [PATCH 02/10] Improve sequence sorting and add unit tests (#3376) Old sequence sort code was showing up on build profiles. Convert it to constexpr functions for much more efficient build-time execution. The sorting is still O(N^2), but our sequences are small enough it executes quickly. 
This reduced compilation time of a small convolution by more than 10% and the overall time spent in the compiler on a narrow build by 6%. --- include/ck/utility/sequence.hpp | 325 ++++++--------- test/CMakeLists.txt | 1 + test/util/CMakeLists.txt | 7 + test/util/unit_sequence.cpp | 684 ++++++++++++++++++++++++++++++++ 4 files changed, 808 insertions(+), 209 deletions(-) create mode 100644 test/util/CMakeLists.txt create mode 100644 test/util/unit_sequence.cpp diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 9f97d44a4a..6e68690048 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -380,236 +380,143 @@ struct sequence_reduce }; #endif -template -struct sequence_sort_impl +// Implement sequence_sort and sequence_unique_sort using constexpr functions (C++17) +namespace sort_impl { + +// Temporary arrays to hold values during operations with capacity N and mutable size. +template +struct IndexedValueArray { - template - struct sorted_sequence_merge_impl - { - static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); - - static constexpr index_t chosen_value = - choose_left ? LeftValues::Front() : RightValues::Front(); - static constexpr index_t chosen_id = choose_left ? 
LeftIds::Front() : RightIds::Front(); - - using new_merged_values = decltype(MergedValues::PushBack(Number{})); - using new_merged_ids = decltype(MergedIds::PushBack(Number{})); - - using new_left_values = - typename conditional::type; - using new_left_ids = - typename conditional::type; - - using new_right_values = - typename conditional::type; - using new_right_ids = - typename conditional::type; - - using merge = sorted_sequence_merge_impl; - // this is output - using merged_values = typename merge::merged_values; - using merged_ids = typename merge::merged_ids; - }; - - template - struct sorted_sequence_merge_impl, - Sequence<>, - MergedValues, - MergedIds, - Comp> - { - using merged_values = typename sequence_merge::type; - using merged_ids = typename sequence_merge::type; - }; - - template - struct sorted_sequence_merge_impl, - Sequence<>, - RightValues, - RightIds, - MergedValues, - MergedIds, - Comp> - { - using merged_values = typename sequence_merge::type; - using merged_ids = typename sequence_merge::type; - }; - - template - struct sorted_sequence_merge - { - using merge = sorted_sequence_merge_impl, - Sequence<>, - Comp>; - - using merged_values = typename merge::merged_values; - using merged_ids = typename merge::merged_ids; - }; - - static constexpr index_t nsize = Values::Size(); - - using split_unsorted_values = sequence_split; - using split_unsorted_ids = sequence_split; - - using left_unsorted_values = typename split_unsorted_values::left_type; - using left_unsorted_ids = typename split_unsorted_ids::left_type; - using left_sort = sequence_sort_impl; - using left_sorted_values = typename left_sort::sorted_values; - using left_sorted_ids = typename left_sort::sorted_ids; - - using right_unsorted_values = typename split_unsorted_values::right_type; - using right_unsorted_ids = typename split_unsorted_ids::right_type; - using right_sort = sequence_sort_impl; - using right_sorted_values = typename right_sort::sorted_values; - using right_sorted_ids = 
typename right_sort::sorted_ids; - - using merged_sorted = sorted_sequence_merge; - - using sorted_values = typename merged_sorted::merged_values; - using sorted_ids = typename merged_sorted::merged_ids; + index_t values[N > 0 ? N : 1]; + index_t ids[N > 0 ? N : 1]; + index_t size = 0; }; -template -struct sequence_sort_impl, Sequence, Compare> +template +constexpr auto make_indexed_value_array(Sequence) { - static constexpr bool choose_x = Compare{}(ValueX, ValueY); + constexpr index_t N = sizeof...(Is); + IndexedValueArray result = {{Is...}, {}, N}; + for(index_t i = 0; i < N; ++i) + { + result.ids[i] = i; + } + return result; +} - using sorted_values = - typename conditional, Sequence>::type; - using sorted_ids = typename conditional, Sequence>::type; +enum class SortField +{ + Values, + Ids }; -template -struct sequence_sort_impl, Sequence, Compare> +// Perform an insertion sort on an IndexedValueArray. +template +constexpr auto insertion_sort(IndexedValueArray arr, Compare comp) { - using sorted_values = Sequence; - using sorted_ids = Sequence; + for(index_t i = 1; i < arr.size; ++i) + { + index_t key_val = arr.values[i]; + index_t key_id = arr.ids[i]; + index_t j = i - 1; + while(j >= 0 && comp(key_val, arr.values[j])) + { + arr.values[j + 1] = arr.values[j]; + arr.ids[j + 1] = arr.ids[j]; + --j; + } + arr.values[j + 1] = key_val; + arr.ids[j + 1] = key_id; + } + return arr; +} + +// Remove duplicates from a sorted IndexedValueArray. 
+template +constexpr auto unique(const IndexedValueArray& sorted, Equal eq) +{ + IndexedValueArray result{}; + if constexpr(N == 0) + { + return result; + } + result.size = 1; + result.values[0] = sorted.values[0]; + result.ids[0] = sorted.ids[0]; + for(index_t i = 1; i < sorted.size; ++i) + { + if(!eq(sorted.values[i], sorted.values[i - 1])) + { + result.values[result.size] = sorted.values[i]; + result.ids[result.size] = sorted.ids[i]; + ++result.size; + } + } + return result; +} + +// Compute sorted (and optionally unique) IndexedValueArray from input Sequence. +template +constexpr auto compute_sorted(Sequence seq, Compare comp, Equal eq) +{ + auto sorted = insertion_sort(make_indexed_value_array(seq), comp); + return Unique ? unique(sorted, eq) : sorted; +} + +// Cache the sorted results to avoid recomputation. +template +struct SortedCache +{ + static constexpr auto data = compute_sorted(Seq{}, Compare{}, Equal{}); }; -template -struct sequence_sort_impl, Sequence<>, Compare> +// Build sorted value and ID sequences from cached sorted data +template +constexpr index_t get_sorted_field() { - using sorted_values = Sequence<>; - using sorted_ids = Sequence<>; + constexpr auto& data = SortedCache::data; + return (Field == SortField::Values) ? 
data.values[I] : data.ids[I]; +} + +template +struct SortedSequences; + +template +struct SortedSequences> +{ + using values_type = + Sequence()...>; + using ids_type = + Sequence()...>; }; +template +using sorted_sequences_t = SortedSequences< + Unique, + Seq, + Compare, + Equal, + typename arithmetic_sequence_gen<0, SortedCache::data.size, 1>:: + type>; + +using Equal = ck::math::equal; + +} // namespace sort_impl + template struct sequence_sort { - using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; - using sort = sequence_sort_impl; - - // this is output - using type = typename sort::sorted_values; - using sorted2unsorted_map = typename sort::sorted_ids; + using sorted_seqs = sort_impl::sorted_sequences_t; + using type = typename sorted_seqs::values_type; + using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template struct sequence_unique_sort { - template - struct sorted_sequence_uniquify_impl - { - static constexpr index_t current_value = RemainValues::Front(); - static constexpr index_t current_id = RemainIds::Front(); - - static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); - - using new_remain_values = decltype(RemainValues::PopFront()); - using new_remain_ids = decltype(RemainIds::PopFront()); - - using new_uniquified_values = - typename conditional{})), - UniquifiedValues>::type; - - using new_uniquified_ids = - typename conditional{})), - UniquifiedIds>::type; - - using uniquify = sorted_sequence_uniquify_impl; - - // this is output - using uniquified_values = typename uniquify::uniquified_values; - using uniquified_ids = typename uniquify::uniquified_ids; - }; - - template - struct sorted_sequence_uniquify_impl, - Sequence<>, - UniquifiedValues, - UniquifiedIds, - Eq> - { - using uniquified_values = UniquifiedValues; - using uniquified_ids = UniquifiedIds; - }; - - template - struct sorted_sequence_uniquify - { - using uniquify = sorted_sequence_uniquify_impl, - Sequence, - Eq>; 
- - using uniquified_values = typename uniquify::uniquified_values; - using uniquified_ids = typename uniquify::uniquified_ids; - }; - - using sort = sequence_sort; - using sorted_values = typename sort::type; - using sorted_ids = typename sort::sorted2unsorted_map; - - using uniquify = sorted_sequence_uniquify; - - // this is output - using type = typename uniquify::uniquified_values; - using sorted2unsorted_map = typename uniquify::uniquified_ids; + using sorted_seqs = sort_impl::sorted_sequences_t; + using type = typename sorted_seqs::values_type; + using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c221f11f46..b7db14945d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -310,3 +310,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx12") endif() add_subdirectory(position_embedding) add_subdirectory(scatter_gather) +add_subdirectory(util) diff --git a/test/util/CMakeLists.txt b/test/util/CMakeLists.txt new file mode 100644 index 0000000000..bf0a444f18 --- /dev/null +++ b/test/util/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_gtest_executable(unit_sequence unit_sequence.cpp) +if(result EQUAL 0) + target_link_libraries(unit_sequence PRIVATE utility) +endif() diff --git a/test/util/unit_sequence.cpp b/test/util/unit_sequence.cpp new file mode 100644 index 0000000000..f09fd86e06 --- /dev/null +++ b/test/util/unit_sequence.cpp @@ -0,0 +1,684 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include "ck/utility/sequence.hpp" +#include "ck/utility/functional.hpp" + +using namespace ck; + +// Test basic Sequence construction and properties +TEST(Sequence, BasicConstruction) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + EXPECT_EQ(Seq::Size(), 5); + EXPECT_EQ(Seq::mSize, 5); +} + +TEST(Sequence, EmptySequence) +{ + using Seq = Sequence<>; + EXPECT_EQ(Seq::Size(), 0); + EXPECT_EQ(Seq::mSize, 0); +} + +// Test At() method +TEST(Sequence, AtRuntime) +{ + using Seq = Sequence<10, 20, 30, 40>; + EXPECT_EQ(Seq::At(0), 10); + EXPECT_EQ(Seq::At(1), 20); + EXPECT_EQ(Seq::At(2), 30); + EXPECT_EQ(Seq::At(3), 40); +} + +TEST(Sequence, AtCompileTime) +{ + using Seq = Sequence<10, 20, 30, 40>; + EXPECT_EQ(Seq::At(Number<0>{}), 10); + EXPECT_EQ(Seq::At(Number<1>{}), 20); + EXPECT_EQ(Seq::At(Number<2>{}), 30); + EXPECT_EQ(Seq::At(Number<3>{}), 40); +} + +TEST(Sequence, OperatorBracket) +{ + constexpr auto seq = Sequence<5, 10, 15>{}; + EXPECT_EQ(seq[Number<0>{}], 5); + EXPECT_EQ(seq[Number<1>{}], 10); + EXPECT_EQ(seq[Number<2>{}], 15); +} + +// Test Front() and Back() +TEST(Sequence, FrontBack) +{ + using Seq = Sequence<100, 200, 300>; + EXPECT_EQ(Seq::Front(), 100); + EXPECT_EQ(Seq::Back(), 300); +} + +TEST(Sequence, FrontBackSingleElement) +{ + using Seq = Sequence<42>; + EXPECT_EQ(Seq::Front(), 42); + EXPECT_EQ(Seq::Back(), 42); +} + +// Test PushFront and PushBack +TEST(Sequence, PushFront) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Seq::PushFront(Sequence<1>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushFrontNumbers) +{ + using Seq = Sequence<3, 4>; + using Result = decltype(Seq::PushFront(Number<1>{}, Number<2>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushBack) +{ + using Seq = Sequence<1, 2, 3>; + using Result = decltype(Seq::PushBack(Sequence<4, 5>{})); + using Expected = Sequence<1, 2, 
3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PushBackNumbers) +{ + using Seq = Sequence<1, 2>; + using Result = decltype(Seq::PushBack(Number<3>{}, Number<4>{})); + using Expected = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +// Test PopFront and PopBack +TEST(Sequence, PopFront) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::PopFront()); + using Expected = Sequence<2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, PopBack) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::PopBack()); + using Expected = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); +} + +// Test Extract +TEST(Sequence, ExtractByNumbers) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Result = decltype(Seq::Extract(Number<0>{}, Number<2>{}, Number<4>{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, ExtractBySequence) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Result = decltype(Seq::Extract(Sequence<1, 3>{})); + using Expected = Sequence<20, 40>; + EXPECT_TRUE((is_same::value)); +} + +// Test Modify +TEST(Sequence, Modify) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(Seq::Modify(Number<2>{}, Number<99>{})); + using Expected = Sequence<1, 2, 99, 4>; + EXPECT_TRUE((is_same::value)); +} + +// Test Transform +TEST(Sequence, Transform) +{ + using Seq = Sequence<1, 2, 3, 4>; + auto double_it = [](auto x) { return 2 * x; }; + using Result = decltype(Seq::Transform(double_it)); + using Expected = Sequence<2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +// Test Reverse +TEST(Sequence, Reverse) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + using Result = decltype(Seq::Reverse()); + using Expected = Sequence<5, 4, 3, 2, 1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(Sequence, ReverseSingleElement) +{ + using Seq = Sequence<42>; + using Result = decltype(Seq::Reverse()); + using Expected = Sequence<42>; + 
EXPECT_TRUE((is_same::value)); +} + +// Test ReorderGivenNew2Old +TEST(Sequence, ReorderGivenNew2Old) +{ + using Seq = Sequence<10, 20, 30, 40>; + using Result = decltype(Seq::ReorderGivenNew2Old(Sequence<3, 1, 2, 0>{})); + using Expected = Sequence<40, 20, 30, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test ReorderGivenOld2New +TEST(Sequence, ReorderGivenOld2New) +{ + using Seq = Sequence<10, 20, 30, 40>; + using Result = decltype(Seq::ReorderGivenOld2New(Sequence<3, 1, 2, 0>{})); + using Expected = Sequence<40, 20, 30, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test arithmetic_sequence_gen +TEST(SequenceGen, ArithmeticSequence) +{ + using Result = typename arithmetic_sequence_gen<0, 5, 1>::type; + using Expected = Sequence<0, 1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceWithIncrement) +{ + using Result = typename arithmetic_sequence_gen<0, 10, 2>::type; + using Expected = Sequence<0, 2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceNegativeIncrement) +{ + using Result = typename arithmetic_sequence_gen<10, 5, -1>::type; + using Expected = Sequence<10, 9, 8, 7, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, ArithmeticSequenceEmpty) +{ + using Result = typename arithmetic_sequence_gen<5, 5, 1>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test uniform_sequence_gen +TEST(SequenceGen, UniformSequence) +{ + using Result = typename uniform_sequence_gen<5, 42>::type; + using Expected = Sequence<42, 42, 42, 42, 42>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceZeroSize) +{ + using Result = typename uniform_sequence_gen<0, 42>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test make_index_sequence +TEST(SequenceGen, MakeIndexSequence) +{ + using Result = make_index_sequence<5>; + using Expected = Sequence<0, 1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, 
MakeIndexSequenceZero) +{ + using Result = make_index_sequence<0>; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_merge +TEST(SequenceMerge, MergeTwoSequences) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeMultipleSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeSingleSequence) +{ + using Seq = Sequence<1, 2, 3>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_split +TEST(SequenceSplit, SplitInMiddle) +{ + using Seq = Sequence<1, 2, 3, 4, 5, 6>; + using Split = sequence_split; + using ExpectedLeft = Sequence<1, 2, 3>; + using ExpectedRight = Sequence<4, 5, 6>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSplit, SplitAtBeginning) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Split = sequence_split; + using ExpectedLeft = Sequence<>; + using ExpectedRight = Sequence<1, 2, 3, 4>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSplit, SplitAtEnd) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Split = sequence_split; + using ExpectedLeft = Sequence<1, 2, 3, 4>; + using ExpectedRight = Sequence<>; + EXPECT_TRUE((is_same::value)); + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_sort +TEST(SequenceSort, SortAscending) +{ + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 2, 5, 8, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortDescending) +{ + // Create a greater-than comparator + 
struct greater + { + __host__ __device__ constexpr bool operator()(index_t x, index_t y) const { return x > y; } + }; + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = typename sequence_sort::type; + using Expected = Sequence<9, 8, 5, 2, 1>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortAlreadySorted) +{ + using Seq = Sequence<1, 2, 3, 4, 5>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 2, 3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortWithDuplicates) +{ + using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<1, 1, 2, 3, 4, 5, 5, 6, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortEmptySequence) +{ + using Seq = Sequence<>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceSort, SortSingleElement) +{ + using Seq = Sequence<42>; + using Result = typename sequence_sort>::type; + using Expected = Sequence<42>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_unique_sort +TEST(SequenceUniqueSort, UniqueSort) +{ + using Seq = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceUniqueSort, UniqueSortNoDuplicates) +{ + using Seq = Sequence<5, 2, 8, 1, 9>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<1, 2, 5, 8, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceUniqueSort, UniqueSortAllSame) +{ + using Seq = Sequence<5, 5, 5, 5>; + using Result = + typename sequence_unique_sort, math::equal>::type; + using Expected = Sequence<5>; + EXPECT_TRUE((is_same::value)); +} + +// Test is_valid_sequence_map +TEST(SequenceMap, ValidMap) +{ + using Map = Sequence<0, 1, 2, 3>; + 
EXPECT_TRUE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, ValidMapPermuted) +{ + using Map = Sequence<2, 0, 3, 1>; + EXPECT_TRUE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, InvalidMapDuplicate) +{ + using Map = Sequence<0, 1, 1, 3>; + EXPECT_FALSE((is_valid_sequence_map::value)); +} + +TEST(SequenceMap, InvalidMapMissing) +{ + using Map = Sequence<0, 1, 3, 4>; + EXPECT_FALSE((is_valid_sequence_map::value)); +} + +// Test sequence_map_inverse +// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new +// position j. The inverse tells us that new position i came from old position inverse[i]. +TEST(SequenceMapInverse, InverseMap) +{ + // Map = <2, 0, 3, 1> means: old[0]->new[2], old[1]->new[0], old[2]->new[3], old[3]->new[1] + // Inverse should be: new[0]<-old[1], new[1]<-old[3], new[2]<-old[0], new[3]<-old[2] + using Map = Sequence<2, 0, 3, 1>; + using Result = typename sequence_map_inverse::type; + // Verify by checking that Map[Result[i]] == i for all i + EXPECT_EQ((Map::At(Number{})>{}) == 0), true); + EXPECT_EQ((Map::At(Number{})>{}) == 1), true); + EXPECT_EQ((Map::At(Number{})>{}) == 2), true); + EXPECT_EQ((Map::At(Number{})>{}) == 3), true); +} + +TEST(SequenceMapInverse, InverseIdentityMap) +{ + using Map = Sequence<0, 1, 2, 3>; + using Result = typename sequence_map_inverse::type; + // Verify by checking that Map[Result[i]] == i for all i (same as the other test) + EXPECT_EQ((Map::At(Number{})>{}) == 0), true); + EXPECT_EQ((Map::At(Number{})>{}) == 1), true); + EXPECT_EQ((Map::At(Number{})>{}) == 2), true); + EXPECT_EQ((Map::At(Number{})>{}) == 3), true); +} + +// Test sequence operators +TEST(SequenceOperators, Equality) +{ + constexpr auto seq1 = Sequence<1, 2, 3>{}; + constexpr auto seq2 = Sequence<1, 2, 3>{}; + constexpr auto seq3 = Sequence<1, 2, 4>{}; + EXPECT_TRUE(seq1 == seq2); + EXPECT_FALSE(seq1 == seq3); +} + +TEST(SequenceOperators, Addition) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 
= Sequence<4, 5, 6>; + using Result = decltype(Seq1{} + Seq2{}); + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Subtraction) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<1, 2, 3>; + using Result = decltype(Seq1{} - Seq2{}); + using Expected = Sequence<9, 18, 27>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Multiplication) +{ + using Seq1 = Sequence<2, 3, 4>; + using Seq2 = Sequence<5, 6, 7>; + using Result = decltype(Seq1{} * Seq2{}); + using Expected = Sequence<10, 18, 28>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Division) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<2, 4, 5>; + using Result = decltype(Seq1{} / Seq2{}); + using Expected = Sequence<5, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, Modulo) +{ + using Seq1 = Sequence<10, 20, 30>; + using Seq2 = Sequence<3, 7, 8>; + using Result = decltype(Seq1{} % Seq2{}); + using Expected = Sequence<1, 6, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, AdditionWithNumber) +{ + using Seq = Sequence<1, 2, 3>; + using Result = decltype(Seq{} + Number<10>{}); + using Expected = Sequence<11, 12, 13>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, SubtractionWithNumber) +{ + using Seq = Sequence<10, 20, 30>; + using Result = decltype(Seq{} - Number<5>{}); + using Expected = Sequence<5, 15, 25>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, MultiplicationWithNumber) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Seq{} * Number<3>{}); + using Expected = Sequence<6, 9, 12>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, DivisionWithNumber) +{ + using Seq = Sequence<10, 20, 30>; + using Result = decltype(Seq{} / Number<5>{}); + using Expected = Sequence<2, 4, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, NumberAddition) +{ + using Seq = Sequence<1, 2, 3>; + using Result 
= decltype(Number<10>{} + Seq{}); + using Expected = Sequence<11, 12, 13>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceOperators, NumberMultiplication) +{ + using Seq = Sequence<2, 3, 4>; + using Result = decltype(Number<3>{} * Seq{}); + using Expected = Sequence<6, 9, 12>; + EXPECT_TRUE((is_same::value)); +} + +// Test helper functions +TEST(SequenceHelpers, MergeSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = decltype(merge_sequences(Seq1{}, Seq2{}, Seq3{})); + using Expected = Sequence<1, 2, 3, 4, 5, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesSingle) +{ + auto double_it = [](auto x) { return 2 * x; }; + using Seq = Sequence<1, 2, 3>; + using Result = decltype(transform_sequences(double_it, Seq{})); + using Expected = Sequence<2, 4, 6>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesTwo) +{ + auto add = [](auto x, auto y) { return x + y; }; + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = decltype(transform_sequences(add, Seq1{}, Seq2{})); + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, TransformSequencesThree) +{ + auto add3 = [](auto x, auto y, auto z) { return x + y + z; }; + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Seq3 = Sequence<7, 8, 9>; + using Result = decltype(transform_sequences(add3, Seq1{}, Seq2{}, Seq3{})); + using Expected = Sequence<12, 15, 18>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceHelpers, ReduceOnSequence) +{ + auto add = [](auto x, auto y) { return x + y; }; + constexpr auto seq = Sequence<1, 2, 3, 4, 5>{}; + constexpr auto result = reduce_on_sequence(seq, add, Number<0>{}); + EXPECT_EQ(result, 15); +} + +TEST(SequenceHelpers, SequenceAnyOf) +{ + auto is_even = [](auto x) { return x % 2 == 0; }; + constexpr auto seq1 = Sequence<1, 3, 5, 7>{}; + constexpr auto 
seq2 = Sequence<1, 3, 4, 7>{}; + EXPECT_FALSE(sequence_any_of(seq1, is_even)); + EXPECT_TRUE(sequence_any_of(seq2, is_even)); +} + +TEST(SequenceHelpers, SequenceAllOf) +{ + auto is_positive = [](auto x) { return x > 0; }; + constexpr auto seq1 = Sequence<1, 2, 3, 4>{}; + constexpr auto seq2 = Sequence<1, -2, 3, 4>{}; + EXPECT_TRUE(sequence_all_of(seq1, is_positive)); + EXPECT_FALSE(sequence_all_of(seq2, is_positive)); +} + +// Test scan operations +TEST(SequenceScan, ReverseInclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = + decltype(reverse_inclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<10, 9, 7, 4>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceScan, ReverseExclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = + decltype(reverse_exclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<9, 7, 4, 0>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceScan, InclusiveScan) +{ + using Seq = Sequence<1, 2, 3, 4>; + using Result = decltype(inclusive_scan_sequence(Seq{}, math::plus{}, Number<0>{})); + using Expected = Sequence<1, 3, 6, 10>; + EXPECT_TRUE((is_same::value)); +} + +// Test pick and modify operations +TEST(SequencePickModify, PickElementsByIds) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Ids = Sequence<0, 2, 4>; + using Result = decltype(pick_sequence_elements_by_ids(Seq{}, Ids{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequencePickModify, PickElementsByMask) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Mask = Sequence<1, 0, 1, 0, 1>; + using Result = decltype(pick_sequence_elements_by_mask(Seq{}, Mask{})); + using Expected = Sequence<10, 30, 50>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequencePickModify, ModifyElementsByIds) +{ + using Seq = Sequence<10, 20, 30, 40, 50>; + using Values = Sequence<99, 88>; + using Ids = Sequence<1, 3>; + using Result = 
decltype(modify_sequence_elements_by_ids(Seq{}, Values{}, Ids{})); + using Expected = Sequence<10, 99, 30, 88, 50>; + EXPECT_TRUE((is_same::value)); +} + +// Test sequence_reduce +TEST(SequenceReduce, ReduceTwoSequences) +{ + using Seq1 = Sequence<1, 2, 3>; + using Seq2 = Sequence<4, 5, 6>; + using Result = typename sequence_reduce, Seq1, Seq2>::type; + using Expected = Sequence<5, 7, 9>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceReduce, ReduceMultipleSequences) +{ + using Seq1 = Sequence<1, 2>; + using Seq2 = Sequence<3, 4>; + using Seq3 = Sequence<5, 6>; + using Result = typename sequence_reduce, Seq1, Seq2, Seq3>::type; + using Expected = Sequence<9, 12>; + EXPECT_TRUE((is_same::value)); +} From 8270900d606398868e747b7f9097484ee73a4cb4 Mon Sep 17 00:00:00 2001 From: Geo Min Date: Wed, 10 Dec 2025 17:34:41 -0800 Subject: [PATCH 03/10] [ci] Bumping TheRock commit hash (#3385) * Bumping TheRock commit hash * new docker hash * Using new runner name --- .github/workflows/therock-ci-linux.yml | 4 ++-- .github/workflows/therock-ci.yml | 2 +- .github/workflows/therock-test-component.yml | 2 +- .github/workflows/therock-test-packages.yml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 86d134e456..0baa503334 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -20,7 +20,7 @@ jobs: permissions: id-token: write container: - image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:2f3ebd0beb04c449fdb36933e54bdc69483b914fb9005594d3fc9444c206b54b + image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:583d473f263a289222c48d4b493e2956b2354a45796f09dee6f2c8ecd4504ab6 options: -v /runner/config:/home/awsconfig/ env: AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} @@ -54,7 +54,7 @@ jobs: with: repository: "ROCm/TheRock" path: "TheRock" - ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit + ref: 
d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit - name: Setup ccache run: | diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml index 40a3b0bec8..0951244f31 100644 --- a/.github/workflows/therock-ci.yml +++ b/.github/workflows/therock-ci.yml @@ -65,7 +65,7 @@ jobs: -DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON -DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../ amdgpu_families: "gfx94X-dcgpu" - test_runs_on: "linux-mi325-1gpu-ossci-rocm" + test_runs_on: "linux-mi325-1gpu-ossci-rocm-frac" therock_ci_summary: name: TheRock CI Summary diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 27eff4fdb0..565d1d3e54 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,7 +51,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" - ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit + ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 81632fce48..cd255a40b6 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit + ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit - name: "Configuring CI options" env: From fbbdd36ea880aaee1eb4691f1c670492fa388647 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 10:47:19 +0400 Subject: [PATCH 04/10] docs: add notes on tile distribution and inline comments (#3297) * docs: add notes on tile distribution and inline comments * Apply 
suggestions from code review Co-authored-by: spolifroni-amd --------- Co-authored-by: spolifroni-amd --- .../01_naive_gemm/TILE_DISTRIBUTION.md | 312 ++++++++++++++++++ ...ice_gemm_block_policy_agmem_bgmem_creg.hpp | 12 +- ...ce_gemm_host_pipeline_agmem_bgmem_creg.hpp | 2 +- .../ck_tile/01_naive_gemm/practice_gemm.cpp | 34 +- .../ck_tile/01_naive_gemm/practice_gemm.hpp | 7 +- 5 files changed, 347 insertions(+), 20 deletions(-) create mode 100644 tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md diff --git a/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md b/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md new file mode 100644 index 0000000000..275d1a1c12 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md @@ -0,0 +1,312 @@ +# Tile Distribution: Mapping Threads to Data + +## Overview + +**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks. + +## The Problem + +Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine: +- Which threads load which elements. +- How the threads are organized into warps. +- The number of times each warp repeats its pattern. +- The number of elements each thread can load in a single vector instruction. 
+ +--- + +## Bottom-Up Construction Approach + +### Step 1: Determine K Dimension Layout + +**Start with the innermost dimension (K) for memory coalescing:** + +```cpp +constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load) +constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension +``` + +**Example (with fp16):** +- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction +- `kKPerBlock = 32` +- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension + +**Visual:** +``` +K dimension (32 elements): +Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31] + K1=8 K1=8 K1=8 K1=8 +├──────────────────────────────────────────────────────────────┤ + K0=4 threads +``` + +--- + +### Step 2: Determine M Dimension Layout + +**Now partition the M dimension hierarchically:** + +#### Level 1: Threads per Warp in M (M2) + +```cpp +constexpr index_t M2 = get_warp_size() / K0; +``` + +- Warp size = 64 threads +- K dimension already uses `K0 = 4` threads per row +- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension + +**Visual (Single Warp):** +``` + K dimension (4 threads) + ┌─────┬─────┬─────┬─────┐ + 0 │ T0 │ T1 │ T2 │ T3 │ + 1 │ T4 │ T5 │ T6 │ T7 │ + 2 │ T8 │ T9 │ T10 │ T11 │ +M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows + ...│ ... │ ... │ ... │ ... 
│ (M2=16) + 15 │ T60 │ T61 │ T62 │ T63 │ + └─────┴─────┴─────┴─────┘ + One Warp = 64 threads +``` + +#### Level 2: Warps per Block (M1) + +```cpp +constexpr index_t M1 = kBlockSize / get_warp_size(); +``` + +- `kBlockSize = 256` threads per block +- `M1 = 256 / 64 = 4` → We have 4 warps per block + +**Visual (4 Warps):** +``` + Warp 0 (rows 0-15) + Warp 1 (rows 16-31) + Warp 2 (rows 32-47) + Warp 3 (rows 48-63) + ↑ + M1 = 4 warps cover 64 rows total +``` + +#### Level 3: Repetitions (M0) + +```cpp +constexpr index_t M0 = kMPerBlock / (M2 * M1); +``` + +- `kMPerBlock = 256` rows to cover +- `M2 * M1 = 16 * 4 = 64` rows covered by all warps +- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times + +**Visual (Complete Block):** +``` +┌──────────────┐ +│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ... +│ (rows 0-63) │ +├──────────────┤ +│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ... +│ (rows 64-127)│ +├──────────────┤ +│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ... +│(rows 128-191)│ +├──────────────┤ +│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ... +│(rows 192-255)│ +└──────────────┘ + M0 = 4 iterations +``` + +--- + +## The Tile Distribution Encoding + +Now we can construct the distribution: + +```cpp +tile_distribution_encoding< + sequence<1>, // [1] Replication + tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy + tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism + tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism + sequence<1, 2>, // [4] Yield + sequence<0, 1> // [4] Yield +> +``` + +### [1] Replication: `sequence<1>` + +Defines how many times warp patterns are replicated: +- `1` = Each warp has a unique pattern (no replication) +- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing +- `4` = All warps do the same thing + +In our case: `1` means no replication (each warp is independent).
+ +--- + +### [2] Hierarchy: The Multi-Level Structure + +```cpp +tuple<sequence<M0, M1, M2>, sequence<K0, K1>> + └───────┬──────────┘ └──────┬────────┘ + M dimension K dimension +``` + +**Concrete values:** +- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp) +- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread) + +--- + +### [3] Parallelism: Addressing the Hierarchy + +**The key insight:** Read the tuples **vertically** to understand indexing! + +```cpp +tuple<sequence<1>, sequence<1, 2>> +tuple<sequence<1>, sequence<2, 0>> +``` + +#### Reading Pattern + +**Column 1 (Dimension 0 = M):** +``` +sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension) +sequence<1> +``` + +**Column 2 (Dimension 1 = K):** +``` +sequence<1, 2> +sequence<2, 0> +``` +[1,2] M2=threads/warp in M dimension +[2,0] K0=threads/warp in K dimension + +--- + +### [4] Yield Sequences: Output Ordering + +```cpp +sequence<1, 2> +sequence<0, 1> + +[1,0] means M0=repetitions/warp in M dimension +[2,1] means K1=elements/thread in K dimension +``` +--- + +## Complete Example: Thread 25 in Warp 0 + +Let's trace where **Thread 25** in **Warp 0** reads data: + +### Thread Coordinates +- Thread ID in warp: 25 +- Warp ID in block: 0 + +### Decompose Thread 25 +``` +Thread 25 in a 2D layout (M2=16, K0=4): +Row index: 25 / 4 = 6 +Col index: 25 % 4 = 1 +``` + +### M Position (Row) +``` +M0 iteration: 0 (first iteration) +M1 warp: 0 (warp 0) +M2 thread: 6 (row index 6 in warp) +→ M position = 0*64 + 0*16 + 6 = 6 +``` + +### K Position (Column) +``` +K0 thread: 1 (column group 1) +K1 elements: 8 (will load 8 consecutive elements) +→ K position = 1*8 + [0-7] = elements 8-15 +``` + +**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements). + +--- + +## Why This Matters + +### 1. **Memory Coalescing** +- Consecutive threads access consecutive memory → efficient global memory access +- K dimension uses K1=8 for vectorized loads + +### 2.
**Warp Efficiency** +- All 64 threads in a warp are utilized +- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads + +### 3. **Scalability** +- M0 repetitions allow handling larger tiles +- Same pattern scales to different sizes + +### 4. **Register Allocation** +- Each thread knows exactly how many elements it will hold +- Compiler can allocate registers optimally + +--- + +## Summary Table + +| Parameter | Value | Meaning | +|-----------|-------|---------| +| **K1** | 8 | Elements per thread (vector width) | +| **K0** | 4 | Threads along K per row | +| **M2** | 16 | Threads along M per warp | +| **M1** | 4 | Warps per block | +| **M0** | 4 | Repetitions of warp pattern | +| **Total Threads** | 256 | M1×M2×K0 = 4×16×4 (M1 warps × 64 threads per warp) | +| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) | +| **Elements/Thread** | 32 | M0×K1 = 4×8 | + +--- + +## Visualization: Complete Thread Block + +``` +Block Tile: 256×32 + + K dimension (32 elements) + ├─────────────────────┤ + ┌──────────────────────┐ ┐ + 0 │ Warp 0 │ │ + 16 │ Warp 1 │ │ Iteration 0 + 32 │ Warp 2 │ │ (M0=0) + 48 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 64 │ Warp 0 │ │ + 80 │ Warp 1 │ │ Iteration 1 + 96 │ Warp 2 │ │ (M0=1) + 112 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 128 │ Warp 0 │ │ + 144 │ Warp 1 │ │ Iteration 2 + 160 │ Warp 2 │ │ (M0=2) + 176 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 192 │ Warp 0 │ │ + 208 │ Warp 1 │ │ Iteration 3 + 224 │ Warp 2 │ │ (M0=3) + 240 │ Warp 3 │ ┘ + 256 └──────────────────────┘ + +Each warp processes 16 rows × 32 cols = 512 elements +Each iteration processes 64 rows × 32 cols = 2048 elements +Total: 4 iterations × 2048 = 8192 elements ✓ +``` + +--- + +## Key Takeaways + +1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy +2. **Vertical reading**: The two parallelism tuples are read column-wise to address hierarchy levels +3.
**Replication controls redundancy**: How many warps share the same pattern +4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping + +This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping! + diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp index 2921bce8bf..a3ed982488 100644 --- a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp @@ -98,12 +98,12 @@ struct PracticeGemmBlockPolicy constexpr index_t M0 = kMPerBlock / (M2 * M1); return make_static_tile_distribution( - tile_distribution_encoding<sequence<1>, - tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, - tuple<sequence<1>, sequence<1, 2>>, - tuple<sequence<1>, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tile_distribution_encoding<sequence<1>, // replication + tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // hierarchy + tuple<sequence<1>, sequence<1, 2>>, // parallelism + tuple<sequence<1>, sequence<2, 0>>, // parallelism + sequence<1, 2>, // yield + sequence<0, 1>>{}); // yield } template diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp index dd72f08d99..15c1743a86 100644 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -24,7 +24,7 @@ struct PracticeGemmHostPipeline template CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, const BDRAMTensorView& b_dram, - CDRAMTensorView& c_dram_ref) const + CDRAMTensorView& c_dram) const { // Size of the entire problem diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp
index 4f0bc13dd5..7635c9376b 100644 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp @@ -6,7 +6,7 @@ #include "practice_gemm.hpp" #include "reference_gemm.hpp" -int main() +int main(int argc, char* argv[]) { // TODO: GemmTypeConfig using ADataType = ck_tile::half_t; @@ -14,11 +14,22 @@ int main() using CDataType = float; using AccDataType = float; - // ArgParser - ck_tile::index_t M = 512; - ck_tile::index_t N = 256; - ck_tile::index_t K = 64; - ck_tile::index_t verification = 1; + // Setup simple argument parser for M, N, K + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "512", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "64", "k dimension") + .insert("v", "1", "verification: 0=off, 1=on"); + + auto result = arg_parser.parse(argc, argv); + if(!result) + return -1; + + // Get problem dimensions from command line + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + ck_tile::index_t verification = arg_parser.get_int("v"); ck_tile::index_t stride_a = K; ck_tile::index_t stride_b = K; @@ -61,9 +72,6 @@ int main() ck_tile::DeviceMem c_device(c_host); // TODO: BlockTileConfig - // constexpr ck_tile::index_t warpSize = 64; - constexpr ck_tile::index_t kBlockSize = 256; - using BlockTile = ck_tile::sequence<256, 128, 32>; using WaveTile = ck_tile::sequence<16, 16, 16>; @@ -77,11 +85,13 @@ int main() ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); - std::cout << "kGridSize: " << kGridSize << std::endl; + std::cout << "Total number of thread blocks: " << kGridSize << std::endl; constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU - std::cout << "kBlockSize: " << kBlockSize << std::endl; - std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl; + // Block size is now 
derived from the shape configuration + constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize; + std::cout << "Number of threads per block: " << kBlockSize << std::endl; + std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl; using gemm_kernel = ck_tile::PracticeGemmKernel; diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp index 850e6ae3b3..91d7fae90c 100644 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp @@ -24,6 +24,10 @@ struct PracticeGemmShape static constexpr index_t WaveTile_N = WaveTile::at(number<1>{}); static constexpr index_t WaveTile_K = WaveTile::at(number<2>{}); + // Thread block configuration + static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called wavefront) + static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads) + CK_TILE_HOST static std::string GetName() { // clang-format off @@ -40,7 +44,8 @@ struct PracticeGemmKernel using Problem = remove_cvref_t; using Policy = remove_cvref_t; - static constexpr index_t kBlockSize = 256; + // Derive block size from the shape configuration + static constexpr index_t kBlockSize = Problem::Shape::kBlockSize; CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, const typename Problem::BDataType* p_b, From 6d25525adc2344d5b62b12b9ffddee50f89cd0ff Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 10:50:43 +0400 Subject: [PATCH 05/10] feat(precommit-hooks): add check for correct copyright header (#3302) * chore(copyright): update copyright header for left files * feat(copyright): add copyright check to precommit hooks * chore(copyright): update copyright header for include/ck_tile directory * chore(copyright): update copyright header for example directory * chore(copyright): update copyright header for .github directory * refactor: copyright_check script with 
better if else handling * chore(copyright): update compyright header for remaining files * feat: add script to automate copyright addition --- .github/scripts/therock_configure_ci.py | 3 + .pre-commit-config.yaml | 12 +- include/ck_tile/core.hpp | 3 +- include/ck_tile/host.hpp | 3 +- include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 3 +- include/ck_tile/ops/batched_contraction.hpp | 3 +- include/ck_tile/ops/batched_transpose.hpp | 3 +- include/ck_tile/ops/common.hpp | 3 +- include/ck_tile/ops/elementwise.hpp | 3 +- include/ck_tile/ops/epilogue.hpp | 3 +- include/ck_tile/ops/flatmm.hpp | 3 +- include/ck_tile/ops/fmha.hpp | 3 +- include/ck_tile/ops/fused_moe.hpp | 3 +- include/ck_tile/ops/gemm.hpp | 3 +- include/ck_tile/ops/gemm_quant.hpp | 3 +- include/ck_tile/ops/grouped_convolution.hpp | 3 +- include/ck_tile/ops/image_to_column.hpp | 3 +- include/ck_tile/ops/layernorm2d.hpp | 3 +- include/ck_tile/ops/norm_reduce.hpp | 3 +- include/ck_tile/ops/permute.hpp | 3 +- include/ck_tile/ops/pooling.hpp | 3 +- include/ck_tile/ops/reduce.hpp | 3 +- include/ck_tile/ops/rmsnorm2d.hpp | 3 +- include/ck_tile/ops/smoothquant.hpp | 3 +- include/ck_tile/ops/softmax.hpp | 3 +- include/ck_tile/ops/topk.hpp | 3 +- include/ck_tile/ops/topk_softmax.hpp | 3 +- include/ck_tile/remod.py | 5 +- script/check_copyright_year.sh | 70 ++++- script/update_amd_copyright_headers.py | 295 ++++++++++++++++++ .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 4 +- test/ck_tile/core/arch/test_arch.cpp | 4 +- tile_engine/include/utility/validation.hpp | 2 +- tile_engine/ops/gemm_streamk/CMakeLists.txt | 3 + .../gemm_streamk/gemm_streamk_benchmark.hpp | 2 +- .../gemm_streamk_benchmark_single.cpp | 2 +- .../ops/gemm_streamk/gemm_streamk_common.hpp | 2 +- .../gemm_streamk_instance_builder.py | 3 + .../gemm_streamk/gemm_streamk_profiler.hpp | 2 +- .../gemm_streamk_validation_utils.py | 2 +- 40 files changed, 408 insertions(+), 78 deletions(-) create mode 100644 script/update_amd_copyright_headers.py diff --git 
a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index 860b6bf875..c892941fc6 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import fnmatch import json import os diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04ebc6b45a..71e7ccdb81 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,12 +20,12 @@ repos: )$ - repo: local hooks: - # - id: copyright-year-checker - # name: copyright-year-checker - # entry: script/check_copyright_year.sh - # verbose: false - # language: script - # types: [c++] + - id: copyright-header-checker + name: Check copyright headers + entry: script/check_copyright_year.sh + verbose: false + language: script + types_or: [c++, python, shell, cmake] - id: remove-exec-bit name: Remove executable bit from non-executable files entry: script/remove_exec_bit.sh diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 5c05e9b6ee..d28d29a0ef 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/core/algorithm/cluster_descriptor.hpp" diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index c769e3e247..b543fd84e9 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/host/arg_parser.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 6c0972e10a..00234b20cf 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 2232ec1261..45fa52e505 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 5822d7b91b..b23e45c233 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index eff2d625b3..94243e674f 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 7f2303932e..5752703ab6 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index ec5a8ef445..555402b53a 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 7ef2fd5433..2d3a819e80 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 5b87a821c9..20714397c9 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index 71721f3408..e6802e82dc 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index ec2d2488c8..d518a15b7e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 3e16d937cb..7dc5b40286 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 23a72d79e9..6743e46613 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 2307b05190..1d33ebf39d 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 9ce22137bf..ebb20aebf4 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index aa074b7f9f..469a98c256 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 46512c57fe..88a3d8a137 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp index 084b498203..3e44122afa 100644 --- a/include/ck_tile/ops/pooling.hpp +++ b/include/ck_tile/ops/pooling.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/pooling/kernel/pool_kernel.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index d628e9c945..57f3f3c80a 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/reduce/block/block_reduce.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 00afcf4aed..ad23a708b7 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 1aa14c69e1..13372f3289 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index d559dc15e2..9cf3e08319 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 040c6b8ddc..090ad0919f 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index d9657a9764..7afce1708b 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp" diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index aeec7bd471..51f3941233 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -1,7 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -from datetime import datetime import pathlib from pathlib import Path import subprocess @@ -13,8 +12,8 @@ OPS = "ops" OPS_COMMON = "common" # common header will be duplicated into ops/* other module IGNORED_DIRS = ["utility", "ref"] -HEADER_COMMON = f"""// SPDX-License-Identifier: MIT -// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n +HEADER_COMMON = """// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT """ diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh index 1b63c6b711..48c050c76b 100755 --- a/script/check_copyright_year.sh +++ b/script/check_copyright_year.sh @@ -2,18 +2,70 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +# This script checks if files have the correct copyright header template. +# It supports .hpp, .cpp, .inc, .py, .sh, and .cmake files. +# +# Usage: ./check_copyright_year.sh ... -current_year=$(date +%Y) exit_code=0 -for file in $@; do - if grep -q "Copyright (c)" $file - then - if ! grep -q "Copyright (c).*$current_year" $file - then - echo "ERROR: File $file has a copyright notice without the current year ($current_year)." - exit_code=1 - fi +# Expected copyright header lines (without comment characters) +COPYRIGHT_LINE="Copyright (c) Advanced Micro Devices, Inc., or its affiliates." 
+SPDX_LINE="SPDX-License-Identifier: MIT" + +check_file() { + local file=$1 + local basename="${file##*/}" + local ext="${file##*.}" + local comment_char + + # Determine comment character based on filename or extension + if [[ "$basename" == "CMakeLists.txt" ]]; then + comment_char="#" + else + case "$ext" in + cpp|hpp|inc) + comment_char="//" + ;; + py|sh|cmake) + comment_char="#" + ;; + *) + # Skip files with unsupported extensions + return 0 + ;; + esac + fi + + # Build expected header patterns + expected_copyright="$comment_char $COPYRIGHT_LINE" + expected_spdx="$comment_char $SPDX_LINE" + + # Check if file contains both required lines + if ! grep -qF "$expected_copyright" "$file"; then + echo "ERROR: File $file is missing the correct copyright header line." + echo " Expected: $expected_copyright" + return 1 + fi + + if ! grep -qF "$expected_spdx" "$file"; then + echo "ERROR: File $file is missing the correct SPDX license identifier line." + echo " Expected: $expected_spdx" + return 1 + fi + + return 0 +} + +# Process each file provided as argument +for file in "$@"; do + # Skip if file doesn't exist or is a directory + if [[ ! -f "$file" ]]; then + continue + fi + + if ! check_file "$file"; then + exit_code=1 fi done diff --git a/script/update_amd_copyright_headers.py b/script/update_amd_copyright_headers.py new file mode 100644 index 0000000000..489b774e97 --- /dev/null +++ b/script/update_amd_copyright_headers.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Purpose: + Normalize and enforce AMD two-line copyright + SPDX headers across files. + +Target files: + - C/C++-style: .cpp, .hpp, .inc -> uses "//" comment style + - Hash-style: .py, .cmake, .sh, and CMakeLists.txt -> uses "#" style + +Header formats inserted (top of file, followed by exactly one blank line): + C/C++ : + // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+ // SPDX-License-Identifier: MIT + + Hash : + + +Shebang special case (hash-style only): + - If line 1 starts with "#!", keep shebang, then a blank line, then the + two hash-style header lines, then a blank line. + +Removal rules: + - Remove any comment lines (anywhere in file) containing the keywords + "copyright" or "spdx" (case-insensitive). Blank lines are preserved. + - Remove long-form MIT license block comment when: + a) The file starts with the block (absolute top), OR + b) The block appears immediately after the AMD header position + (i.e., when remainder at insertion point begins with "/*" and + the first content line is "* The MIT License (MIT)"). + +Blank-line normalization: + - Enforce exactly ONE blank line immediately after the AMD header. + (Drop only the leading blank lines at the insertion point before + re-inserting the header.) + - Do not change blank lines between other non-copyright comments. + +Preservation: + - Preserve original newline style: CRLF (\r\n) vs LF (\n). + - Preserve UTF-8 BOM if present. + - Do not modify non-comment code lines. + +Idempotency: + - Running this script multiple times does not further modify files. 
+""" + +from __future__ import annotations +import re +import sys +from pathlib import Path +from typing import List, Tuple + +AMD_CPP_HEADER_TEXT = [ + "// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.", + "// SPDX-License-Identifier: MIT", +] +AMD_HASH_HEADER_TEXT = [ + "# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.", + "# SPDX-License-Identifier: MIT", +] + +CPP_EXTS = {".cpp", ".hpp", ".inc"} +HASH_EXTS = {".py", ".cmake", ".sh"} + +# --- Encoding helpers ------------------------------------------------------- + + +def has_bom(raw: bytes) -> bool: + return raw.startswith(b"\xef\xbb\xbf") + + +def decode_text(raw: bytes) -> str: + return raw.decode("utf-8-sig", errors="replace") + + +def encode_text(text: str, bom: bool) -> bytes: + data = text.encode("utf-8") + return (b"\xef\xbb\xbf" + data) if bom else data + + +# --- Newline detection ------------------------------------------------------ + + +def detect_newline_sequence(raw: bytes) -> str: + if b"\r\n" in raw: + return "\r\n" + elif b"\n" in raw: + return "\n" + else: + return "\n" + + +# --- Utilities -------------------------------------------------------------- + + +def is_comment_line(line: str, style: str) -> bool: + stripped = line.lstrip() + if style == "cpp": + return ( + stripped.startswith("//") + or stripped.startswith("/*") + or stripped.startswith("*") + or stripped.startswith("*/") + ) + elif style == "hash": + return stripped.startswith("#") + return False + + +def has_keywords(line: str) -> bool: + lower_line = line.lower() + return ("copyright" in lower_line) or ("spdx" in lower_line) + + +# --- MIT License banner detection ------------------------------ +MIT_C_FIRST_LINE_RE = re.compile(r"^\s*\*\s*The MIT License \(MIT\)") +MIT_HASH_FIRST_LINE_RE = re.compile(r"^\s*#\s*The MIT License \(MIT\)") + + +def remove_top_mit_block(lines: List[str]) -> Tuple[List[str], bool]: + """ + Unified MIT banner removal at the top of 'lines'. 
+ Supports: + - C-style block starting with '/*' and ending with '*/'; removes only if + a line within the block matches MIT_C_FIRST_LINE_RE. + - Hash-style banner: contiguous top run of lines starting with '#'; + removes only if any line in that run matches MIT_HASH_FIRST_LINE_RE. + Returns (new_lines, removed_flag). Preserves EOLs. + """ + if not lines: + return lines, False + + first = lines[0].lstrip() + + # C-style block + if first.startswith("/*"): + end_idx, saw_mit = None, False + for i, line in enumerate(lines[1:], 1): + if not saw_mit and MIT_C_FIRST_LINE_RE.match(line): + saw_mit = True + s = line.lstrip() + if s.startswith("*/") or s.rstrip().endswith("*/"): + end_idx = i + 1 + break + if end_idx is not None and saw_mit: + return lines[end_idx:], True + return lines, False + + # Hash-style contiguous banner + if first.startswith("#"): + end_idx, saw_mit = 0, False + for i, line in enumerate(lines): + if line.lstrip().startswith("#"): + if not saw_mit and MIT_HASH_FIRST_LINE_RE.match(line): + saw_mit = True + end_idx = i + 1 + else: + break + if saw_mit: + return lines[end_idx:], True + return lines, False + + return lines, False + + +# --- Removal + normalization helpers --------------------------------------- + + +def remove_keyword_comment_lines_globally(lines: List[str], style: str) -> List[str]: + """Remove comment lines containing keywords anywhere in the file. 
+ **Do not** remove blank lines; preserve all other lines as-is.""" + out: List[str] = [] + for line in lines: + if is_comment_line(line, style) and has_keywords(line): + continue + out.append(line) + return out + + +def drop_leading_blank_lines(lines: List[str]) -> List[str]: + """Drop only the leading blank lines at the start of the given list.""" + i = 0 + while i < len(lines) and lines[i].strip() == "": + i += 1 + return lines[i:] + + +# --- Header builder --------------------------------------------------------- + + +def build_header_lines(style: str, nl: str) -> List[str]: + base = AMD_CPP_HEADER_TEXT if style == "cpp" else AMD_HASH_HEADER_TEXT + return [base[0] + nl, base[1] + nl, nl] # header + exactly one blank + + +# --- Main transforms -------------------------------------------------------- + + +def process_cpp(text: str, nl: str) -> str: + lines = text.splitlines(True) + + # Remove MIT block if it is at the *absolute* top + lines, _ = remove_top_mit_block(lines) + + # Remove keyworded comment lines globally (blank lines preserved) + lines = remove_keyword_comment_lines_globally(lines, style="cpp") + + # Normalize insertion point and remove MIT block if it appears *after header* + lines = drop_leading_blank_lines(lines) + lines, _ = remove_top_mit_block(lines) + + # Prepend AMD header (guarantee exactly one blank after) + return "".join(build_header_lines("cpp", nl) + lines) + + +def process_hash(text: str, nl: str) -> str: + lines = text.splitlines(True) + if not lines: + return "".join(build_header_lines("hash", nl)) + + shebang = lines[0].startswith("#!") + + if shebang: + remainder = remove_keyword_comment_lines_globally(lines[1:], style="hash") + remainder = drop_leading_blank_lines(remainder) + remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header + new_top = [lines[0], nl] + build_header_lines("hash", nl) + return "".join(new_top + remainder) + else: + remainder = remove_keyword_comment_lines_globally(lines, style="hash") 
+ remainder = drop_leading_blank_lines(remainder) + remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header + return "".join(build_header_lines("hash", nl) + remainder) + + +# --- File processing & CLI -------------------------------------------------- + + +def process_file(path: Path) -> bool: + name = path.name + suffix = path.suffix.lower() + if suffix in CPP_EXTS: + style = "cpp" + elif suffix in HASH_EXTS or name == "CMakeLists.txt": + style = "hash" + else: + return False + + raw = path.read_bytes() + bom = has_bom(raw) + nl = detect_newline_sequence(raw) + text = decode_text(raw) + + updated = process_cpp(text, nl) if style == "cpp" else process_hash(text, nl) + if updated != text: + path.write_bytes(encode_text(updated, bom)) + return True + return False + + +def main(argv: List[str]) -> int: + if len(argv) < 2: + print(__doc__) + return 2 + changed = 0 + skipped = 0 + errors: List[str] = [] + for arg in argv[1:]: + p = Path(arg) + try: + if not p.exists(): + errors.append(f"Not found: {p}") + continue + if p.is_dir(): + errors.append(f"Is a directory (pass specific files): {p}") + continue + if process_file(p): + changed += 1 + print(f"Updated: {p}") + else: + skipped += 1 + print(f"Skipped (no change needed or unsupported type): {p}") + except Exception as e: + errors.append(f"Error processing {p}: {e}") + print(f"\nSummary: {changed} updated, {skipped} skipped, {len(errors)} errors") + for msg in errors: + print(f" - {msg}") + return 0 if not errors else 1 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 4121e199e2..c7093e3477 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include #include diff --git a/test/ck_tile/core/arch/test_arch.cpp b/test/ck_tile/core/arch/test_arch.cpp index 2d553c1595..f015d3ce0a 100644 --- a/test/ck_tile/core/arch/test_arch.cpp +++ b/test/ck_tile/core/arch/test_arch.cpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include #include "ck_tile/core/arch/arch.hpp" diff --git a/tile_engine/include/utility/validation.hpp b/tile_engine/include/utility/validation.hpp index dc57e6cc6a..f10f37fbaa 100644 --- a/tile_engine/include/utility/validation.hpp +++ b/tile_engine/include/utility/validation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c), Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/CMakeLists.txt b/tile_engine/ops/gemm_streamk/CMakeLists.txt index acfd78edc5..c692a6d247 100644 --- a/tile_engine/ops/gemm_streamk/CMakeLists.txt +++ b/tile_engine/ops/gemm_streamk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp index fa8a019be5..45beb0acce 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp index 5e88dc486a..9dbba04082 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp index 15a3c91964..2708ac2e56 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 6aebc54564..2225619fad 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp index 256e0b9ca4..0541116522 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py index 2288d7752f..bef3cdfe85 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. From d66e5f667c9d36b9c4ad8fa0cae7dd48ec9d5ebb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:50:00 +0200 Subject: [PATCH 06/10] [CK_BUILDER] Improve CK Builder and CK Builder tests (#3382) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove stale documentation. * Add placeholder for conv algorithm design description. 
Add link to conv factory description. * Improve testing transfer parameters. * Python script to check the block tilings. * Improve tests and conv types serialization. * Change representation of boolean values from 1/0 to true/false in instance strings. * Change representation of boolean values from 1/0 to true/false in conv algorithm types. * Test code improvements. * Improve conv descriptions tests. * Improve conv signature definition in conv fwd builder tests. * clang-format. * Remove obsolete script. * Revert StaticAssertTypeEq changes in conv layout tests. * Remove obsolete using declaration. --------- Co-authored-by: Ville Pietilä <> --- .../builder/include/ck_tile/builder/README.md | 30 +- .../factory/helpers/ck/conv_tensor_type.hpp | 12 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 10 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 10 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 10 +- .../builder/include/ck_tile/builder/types.hpp | 326 ++++--- .../conv/ck/test_ckb_conv_fwd_1d_bf16.cpp | 23 +- .../conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 21 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 21 +- .../conv/ck/test_ckb_conv_fwd_2d_bf16.cpp | 40 +- ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 30 +- .../conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp | 41 +- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 21 +- .../conv/ck/test_ckb_conv_fwd_2d_fp32.cpp | 21 +- .../test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp | 21 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 41 +- .../conv/ck/test_ckb_conv_fwd_3d_bf16.cpp | 24 +- .../conv/ck/test_ckb_conv_fwd_3d_fp16.cpp | 24 +- .../conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 24 +- .../builder/test/conv/ck/test_conv_traits.cpp | 52 +- .../test/impl/conv_signature_types.hpp | 4 +- ..._grouped_convolution_forward_convscale.cpp | 864 +++++++++--------- ...grouped_convolution_forward_dynamic_op.cpp | 48 +- ...rouped_convolution_forward_scaleadd_ab.cpp | 96 +- ...olution_forward_scaleadd_scaleadd_relu.cpp | 96 +-
.../builder/test/test_conv_description.cpp | 48 +- .../builder/test/test_fwd_instance_traits.cpp | 12 +- .../test_instance_string_fwd_grp_conv.cpp | 4 +- ...tance_string_fwd_grp_conv_large_tensor.cpp | 4 +- .../test_instance_string_fwd_grp_conv_v3.cpp | 4 +- .../builder/test/test_testing_utils.cpp | 4 +- .../builder/test/unit_conv_tensor_layout.cpp | 278 +++--- .../test/utils/conv_algorithm_type_utils.hpp | 346 +++++++ 33 files changed, 1568 insertions(+), 1042 deletions(-) create mode 100644 experimental/builder/test/utils/conv_algorithm_type_utils.hpp diff --git a/experimental/builder/include/ck_tile/builder/README.md b/experimental/builder/include/ck_tile/builder/README.md index a0522a50d6..8075e33220 100644 --- a/experimental/builder/include/ck_tile/builder/README.md +++ b/experimental/builder/include/ck_tile/builder/README.md @@ -4,14 +4,16 @@ This directory contains the builder framework for Composable Kernel, which provi ## Table of Contents -- [Convolution Signature Design](#convolution-signature-design) +- [Convolution Signature](#convolution-signature) - [Overview](#overview) - [Architecture](#architecture) - [Core Components](#core-components) - [Concepts and Validation](#concepts-and-validation) +- [Convolution Algorithm](#convolution-algorithm) +- [Convolution Factory](#convolution-factory) --- -## Convolution Signature Design +## Convolution Signature ### Overview @@ -220,25 +222,9 @@ Several fields in the signature are optional: This design follows the principle of "make the common case simple, the complex case possible." -#### Union-Based Layout Representation +## Convolution Algorithm -The `ConvLayout` type uses unions to support dimension-agnostic code: +## Convolution Factory -```cpp -struct ConvLayout { - union { - ConvInputLayout _input_layout; - ConvWeightLayout _weight_layout; - ConvOutputLayout _output_layout; - ConvAuxiliaryTensorLayout _aux_tensor_layout; - }; - // ... 
constructors for each type -}; -``` - -This allows: -- Single type to represent all layout variants -- Type-safe construction through overloaded constructors -- Compile-time enforcement of valid combinations through concepts - ---- +Convolution factory builds the instance based on the convolution signature and convolution algorithm. +The signature and the algorithm descriptions are dispatched to the relevant algorithm specific factory for instance creation. The convolution factory design is described in a separate [Readme](factory/README.md). diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index 81de2140f2..c819e11d00 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -65,17 +65,19 @@ consteval auto GetTensorDataAndComputeTypes() constexpr auto data_type = Config.data_type; constexpr auto compute_type = Config.compute_type; - if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED) + using enum DataType; + + if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE) { return std::make_pair(ConvertDataTypeToCK(), ConvertDataTypeToCK()); } - else if constexpr(data_type == DataType::UNDEFINDED) + else if constexpr(data_type == UNDEFINED_DATA_TYPE) { return std::make_pair(ConvertDataTypeToCK(), ConvertDataTypeToCK()); } - else if constexpr(compute_type == DataType::UNDEFINDED) + else if constexpr(compute_type == UNDEFINED_DATA_TYPE) { return std::make_pair(ConvertDataTypeToCK(), ConvertDataTypeToCK()); @@ -91,7 +93,7 @@ template consteval auto GetTensorAccumulationType() { constexpr auto data_type = SignatureAccDataType; - if constexpr(data_type == DataType::UNDEFINDED) + if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE) { return 
ConvertDataTypeToCK(); } @@ -105,7 +107,7 @@ template consteval auto GetAuxiliaryTensorDataTypeValue() { constexpr auto data_type = Config.data_type; - if constexpr(data_type == DataType::UNDEFINDED) + if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE) { return ConvertDataTypeToCK(); } diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 126be93f01..f5f3df3159 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -316,7 +316,7 @@ struct InstanceTraits; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "NGCW,GKXC,EmptyTuple,NGKW", "PassThrough,PassThrough,Scale", "Filter1x1Stride1Pad0", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index e8cd8fb136..6802e0caf8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature 
FwdConvSignature{.spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NWGC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NWGK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NWGC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -29,11 +34,13 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + expected_transfer_parameters, "NWGC,GKXC,EmptyTuple,NWGK", "PassThrough,PassThrough,PassThrough", "MNKPadding", - "64,64,32,32", "Default"}); } diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 014e221101..14463bbc17 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -14,13 +15,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .data_type = DataType::I8, - .accumulation_data_type = DataType::INT32, - .input = 
{.config = {.layout = TensorLayout::GNWC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::GNWK}}}; + .direction = FORWARD, + .data_type = I8, + .accumulation_data_type = INT32, + .input = {.config = {.layout = GNWC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} @@ -31,8 +36,10 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", - "128,64,64,64", + expected_transfer_parameters, "GNWC,GKXC,EmptyTuple,GNWK", "PassThrough,PassThrough,PassThrough", "Default"}); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index b98e28c45a..4a5618a6b1 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + .direction = 
FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "Default", "NHWGC,GKYXC,EmptyTuple,NHWGK", "PassThrough,PassThrough,PassThrough", @@ -43,13 +50,17 @@ TEST(FwdConvInstances, TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + .direction = FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -60,7 +71,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + expected_transfer_parameters, "Filter3x3", "NHWGC,GKYXC,EmptyTuple,NHWGK", "PassThrough,PassThrough,PassThrough", diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index bc4a5e1047..e3dc261fe3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,19 +13,22 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_BF16_scale_add_relu) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + using enum ck_tile::builder::ElementwiseOperation; + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC, .data_type = DataType::BF16}}, - .output = ConvolutionTensor{ - .config = {.layout = TensorLayout::NHWGK}, - .operation = TensorOperation<>{.elementwise_operation = - ElementwiseOperation::SCALEADD_SCALEADD_RELU} - .with_auxiliary_operand_configs()}}; + .direction = FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC, .data_type = BF16}}, + .output = ConvolutionTensor{ + .config = {.layout = NHWGK}, + .operation = TensorOperation<>{.elementwise_operation = SCALEADD_SCALEADD_RELU} + .with_auxiliary_operand_configs()}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -35,10 +39,12 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder 
= ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", + expected_transfer_parameters, "NHWGC,GKYXC,Tuple(NHWGK,G_K),NHWGK", "PassThrough,PassThrough,ScaleAddScaleAddRelu", - "64,64,32,32", "MNKPadding", "Default"}); } diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index 7af1448403..9bea834ef9 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -10,13 +11,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} @@ -27,8 +32,10 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); 
run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", - "256,128,128,16", + expected_transfer_parameters, "Default", "MNKPadding", "GNHWC,GKYXC,EmptyTuple,GNHWK", @@ -38,13 +45,17 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} @@ -56,8 +67,10 @@ TEST(FwdConvInstances, .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", - "256,128,128,16", + expected_transfer_parameters, "Filter1x1Pad0", "MNKPadding", "GNHWC,GKYXC,EmptyTuple,GNHWK", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 7b522403d3..bba0128810 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { 
@@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "Filter1x1Pad0", "Intrawave", "v3", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index 615d098c7c..79ee4915e8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { + using enum ck_tile::builder::ConvDirection; + 
using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP32, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKCYX}}, - .output = {.config = {.layout = TensorLayout::NGKHW}}}; + .direction = FORWARD, + .data_type = FP32, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKCYX}}, + .output = {.config = {.layout = NGKHW}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,128,128,32", + expected_transfer_parameters, "Filter1x1Stride1Pad0", "Intrawave", "v4", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index 4dd9e2beef..3e3d7e8c2b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP8, - 
.accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + .direction = FORWARD, + .data_type = FP8, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} @@ -29,8 +34,10 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", - "256,256,128,32", + expected_transfer_parameters, "Default", "NHWGC,GKYXC,EmptyTuple,NHWGK", "PassThrough,PassThrough,PassThrough", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 8fe58dbe82..3019c57a18 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = 
{.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ @@ -30,8 +35,10 @@ TEST(FwdConvInstances, .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", - "256,256,128,32", + expected_transfer_parameters, "Default", "GNHWC,GKYXC,EmptyTuple,GNHWK", "PassThrough,PassThrough,PassThrough", @@ -42,13 +49,17 @@ TEST( FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) { + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ @@ -61,8 +72,10 @@ TEST( .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); 
run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", - "128,128,128,32", + expected_transfer_parameters, "Filter1x1Pad0", "GNHWC,GKYXC,EmptyTuple,GNHWK", "PassThrough,PassThrough,PassThrough", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 2df76ab3e0..3f9bdfb972 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .data_type = DataType::BF16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNDHWC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::GNDHWK}}}; + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = FORWARD, + .data_type = BF16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNDHWC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = GNDHWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -30,8 +34,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + 
expected_transfer_parameters, "Default", "Intrawave", "v3", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index ad626d9a15..11c8172533 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NDHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::NDHWGK}}}; + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NDHWGC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NDHWGK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -31,8 +35,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,128,128,32", + expected_transfer_parameters, "Filter1x1Pad0", "Intrawave", "v4", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 85974ace5d..33c01c8ac4 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -3,6 +3,7 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" namespace { @@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils; TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP32, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCDHW}}, - .weight = {.config = {.layout = TensorLayout::GKCZYX}}, - .output = {.config = {.layout = TensorLayout::NGKDHW}}}; + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = FORWARD, + .data_type = FP32, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKCZYX}}, + .output = {.config = {.layout = NGKDHW}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} @@ -31,8 +35,10 @@ TEST(FwdConvInstances, .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", - "256,256,256,32", + expected_transfer_parameters, "Filter1x1Pad0", "Intrawave", "v1", diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index a6a7694703..d052aba548 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp 
+++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -12,6 +12,12 @@ namespace { +using ck_tile::builder::ConvDirection; +using ck_tile::builder::DataType; +using ck_tile::builder::ElementwiseOperation; +using ck_tile::builder::PipelineScheduler; +using ck_tile::builder::PipelineVersion; +using ck_tile::builder::TensorLayout; using ::testing::ElementsAre; // Test fixture for ConvTraits tests @@ -84,15 +90,13 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); EXPECT_THAT(Traits::layout, - ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, - ck_tile::builder::TensorLayout::GKYXC, - ck_tile::builder::TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(Traits::data_type, DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); @@ -145,8 +149,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); // Verify pipeline configuration - EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE); - EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1); + 
EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1); } // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle @@ -214,15 +218,13 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); EXPECT_THAT(Traits::layout, - ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, - ck_tile::builder::TensorLayout::GKYXC, - ck_tile::builder::TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(Traits::data_type, DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); @@ -300,15 +302,13 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) // Verify signature information EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); EXPECT_THAT(Traits::layout, - ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, - ck_tile::builder::TensorLayout::GKYXC, - ck_tile::builder::TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); - 
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(Traits::data_type, DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index ef87981c3d..f046289057 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -14,8 +14,8 @@ struct TensorConfig { TensorLayout layout; // Optional data types, override the type defined in the signature if provided. 
- DataType data_type{DataType::UNDEFINDED}; - DataType compute_type{DataType::UNDEFINDED}; + DataType data_type{DataType::UNDEFINED_DATA_TYPE}; + DataType compute_type{DataType::UNDEFINED_DATA_TYPE}; }; template diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp index 3b3b0fa7e1..60af599551 100644 --- a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_convscale.cpp @@ -76,54 +76,54 @@ struct F8_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", 
- "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -141,54 +141,54 @@ struct F8_BF8_comb1_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>" // clang-format on }; }; @@ -206,54 +206,54 @@ struct F8_BF8_comb2_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>" // clang-format on }; }; @@ -271,54 +271,54 @@ struct F8_BF8_comb3_ConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>" // clang-format on }; }; @@ -336,54 +336,54 @@ struct F8_float_CombConvScale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -401,54 +401,54 @@ struct F8_ConvScaleRelu constexpr static auto expected = { // clang-format off - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", 
- "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", 
+ "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -466,54 +466,54 @@ struct F8_CombConvScaleRelu constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -531,54 +531,54 @@ struct F8_ConvScaleAdd constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; @@ -596,54 +596,54 @@ struct F8_ConvInvscale constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp index 2e06ebc74c..6aa2f57db2 100644 --- 
a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp @@ -85,9 +85,9 @@ struct DyOp_F32_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -98,9 +98,9 @@ struct DyOp_F32_3 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -111,9 +111,9 @@ struct DyOp_F16_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -124,9 +124,9 @@ struct DyOp_F16_3 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -137,9 +137,9 @@ struct DyOp_BF16_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -150,9 +150,9 @@ struct DyOp_BF16_3 constexpr static auto expected = { // clang-format off - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -163,9 +163,9 @@ struct DyOp_INT8_2 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; @@ -176,9 +176,9 @@ struct DyOp_INT8_3 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp index 56843b214f..918642c266 100644 --- a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp @@ -53,18 +53,18 @@ struct F32 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -75,18 +75,18 @@ struct F16 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -97,18 +97,18 @@ struct BF16 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -119,18 +119,18 @@ struct S8 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp index a833a1fe87..74f5f5e231 
100644 --- a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp @@ -54,18 +54,18 @@ struct F32 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>" // clang-format on }; }; @@ -76,18 +76,18 @@ struct F16 constexpr static auto expected = { // clang-format off - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>" // clang-format on }; }; @@ -98,18 +98,18 @@ struct BF16 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>" // clang-format on }; }; @@ -120,18 +120,18 @@ struct S8 constexpr static auto expected = { // clang-format off - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>", + 
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>" // clang-format on }; }; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 689577fb3b..ace9ce0239 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -30,8 +30,8 @@ static_assert(!ckb::TensorOperatorDescriptor); struct TensorConfig { ckb::TensorLayout layout; - ckb::DataType data_type{ckb::DataType::UNDEFINDED}; - ckb::DataType compute_type{ckb::DataType::UNDEFINDED}; + ckb::DataType data_type{ckb::DataType::UNDEFINED_DATA_TYPE}; + ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE}; }; struct ConvTensorSimple @@ -55,39 +55,49 @@ struct ConvTensorWithInvalidOp // This includes dimensionality, direction, data layout, and data type. 
struct ConvSignature { + using enum ckb::DataType; + using enum ckb::TensorLayout; + int spatial_dim = 2; - ckb::DataType data_type = ckb::DataType::FP16; - ckb::DataType accumulation_data_type = ckb::DataType::FP32; - ConvTensorSimple input = {.config = {ckb::TensorLayout::GNHWC}}; - ConvTensorSimple weight = {.config = {ckb::TensorLayout::GKYXC}}; - ConvTensorSimple output = {.config = {ckb::TensorLayout::GNHWK}}; + ckb::DataType data_type = FP16; + ckb::DataType accumulation_data_type = FP32; + ConvTensorSimple input = {.config = {GNHWC}}; + ConvTensorSimple weight = {.config = {GKYXC}}; + ConvTensorSimple output = {.config = {GNHWK}}; }; static_assert(ckb::ConvSignatureDescriptor); // Compile time tests for concepts struct ConvSignatureWithOptionalParams { + using enum ckb::DataType; + using enum ckb::TensorLayout; + using enum ckb::ConvDirection; + using enum ckb::ElementwiseOperation; + int spatial_dim = 2; - ckb::DataType data_type = ckb::DataType::FP16; - ckb::DataType accumulation_data_type = ckb::DataType::FP32; - ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; + ckb::DataType data_type = FP16; + ckb::DataType accumulation_data_type = FP32; + ckb::ConvDirection direction = FORWARD; ConvTensorWithOp input = { - .config = {ckb::TensorLayout::GNHWC, ckb::DataType::FP16}, + .config = {GNHWC, FP16}, }; - ConvTensorWithOp weight = {.config = {ckb::TensorLayout::GKYXC, ckb::DataType::FP16}}; - ConvTensorWithOp output = {.config = {ckb::TensorLayout::GNHWK, ckb::DataType::FP16}, - .operation = {ckb::ElementwiseOperation::SCALE}}; + ConvTensorWithOp weight = {.config = {GKYXC, FP16}}; + ConvTensorWithOp output = {.config = {GNHWK, FP16}, .operation = {SCALE}}; }; static_assert(ckb::ConvSignatureDescriptor); struct ConvSignatureWithInvalidOptionalParams { + using enum ckb::DataType; + using enum ckb::TensorLayout; + int spatial_dim = 2; - ckb::DataType data_type = ckb::DataType::FP16; - ckb::DataType accumulation_data_type = ckb::DataType::FP32; - 
ConvTensorWithInvalidOp input = {.config = {ckb::TensorLayout::GNHWC}}; - ConvTensorWithInvalidOp weight = {.config = {ckb::TensorLayout::GKYXC}}; - ConvTensorWithInvalidOp output = {.config = {ckb::TensorLayout::GNHWK}}; + ckb::DataType data_type = FP16; + ckb::DataType accumulation_data_type = FP32; + ConvTensorWithInvalidOp input = {.config = {GNHWC}}; + ConvTensorWithInvalidOp weight = {.config = {GKYXC}}; + ConvTensorWithInvalidOp output = {.config = {GNHWK}}; }; static_assert(!ckb::ConvSignatureDescriptor); diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 396533cef4..6dd2a4eada 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -262,14 +262,14 @@ TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat) ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths @@ -377,14 +377,14 @@ TEST(InstanceTraits, BaseInstanceStringReturnsCorrectFormat) ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // 
BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths @@ -492,14 +492,14 @@ TEST(InstanceTraits, LargeTensorInstanceStringReturnsCorrectFormat) ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp index 9929f276a7..35f3db1469 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp @@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" ",2" // ABlockTransferSrcVectorDim ",1" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",1" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,16,1,4)" 
// CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp index aecce25f1d..26b50bea6d 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp @@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Ten ",2" // ABlockTransferSrcVectorDim ",1" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",1" // ABlockLdsExtraM + ",true" // ABlockLdsExtraM ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",1" // BBlockTransferSrcScalarPerVector ",8" // BBlockTransferDstScalarPerVector_BK1 - ",1" // BBlockLdsExtraN + ",true" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp index 7eeaec8e25..604667dd10 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp @@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" ",2" // ABlockTransferSrcVectorDim ",8" // ABlockTransferSrcScalarPerVector ",8" // ABlockTransferDstScalarPerVector_AK1 - ",0" // ABlockLdsExtraM + ",false" // ABlockLdsExtraM ",Seq(8,32,1)" // BBlockTransferThreadClusterLengths ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder ",2" // BBlockTransferSrcVectorDim ",8" // BBlockTransferSrcScalarPerVector ",8" // 
BBlockTransferDstScalarPerVector_BK1 - ",0" // BBlockLdsExtraN + ",false" // BBlockLdsExtraN ",1" // CShuffleMXdlPerWavePerShuffle ",1" // CShuffleNXdlPerWavePerShuffle ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths diff --git a/experimental/builder/test/test_testing_utils.cpp b/experimental/builder/test/test_testing_utils.cpp index 694bec4c20..dd65f3f327 100644 --- a/experimental/builder/test/test_testing_utils.cpp +++ b/experimental/builder/test/test_testing_utils.cpp @@ -34,8 +34,8 @@ TEST(InstanceSet, FromFactory) const auto* el = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16," "fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Default,MNKPadding,1,128," - "128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2)," - "Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp16,fp16,Default,1>"; + "128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1)," + "Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp16,fp16,Default,1>"; EXPECT_THAT(instances.instances, testing::Contains(el)); } diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 26df33cc8d..ce31f41933 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -9,27 +9,34 @@ namespace { -namespace ckb = ::ck_tile::builder; -using ::ck_tile::builder::DataType; -using ::ck_tile::builder::ElementwiseOperation; -using ::ck_tile::builder::TensorLayout; -using ::ck_tile::builder::factory::internal::AuxiliaryTensorLayouts; -using ::ck_tile::builder::factory::internal::ConvTensorLayouts; -using ::ck_tile::builder::factory::internal::LayoutToCK; +namespace ckb = ck_tile::builder; +using ck_tile::builder::DataType; +using ck_tile::builder::ElementwiseOperation; +using ck_tile::builder::TensorLayout; +using ck_tile::builder::factory::internal::AuxiliaryTensorLayouts; 
+using ck_tile::builder::factory::internal::ConvTensorLayouts; +using ck_tile::builder::factory::internal::LayoutToCK; +using ck_tile::builder::test::ConvolutionTensor; +using ck_tile::builder::test::ConvSignature; +using ck_tile::builder::test::TensorConfig; +using ck_tile::builder::test::TensorOperation; -using namespace ::ck_tile::builder::test; -using enum ::ck_tile::builder::ConvDirection; +namespace enums { +using enum ck_tile::builder::ConvDirection; +using enum ck_tile::builder::TensorLayout; +using enum ck_tile::builder::DataType; +} // namespace enums TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NWGC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NWGK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NWGC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NWGK}}}; using TensorLayouts = ConvTensorLayouts; @@ -41,14 +48,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCW}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NGKW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config 
= {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; using TensorLayouts = ConvTensorLayouts; @@ -60,14 +67,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNWC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::GNWK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = GNWC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = GNWK}}}; using TensorLayouts = ConvTensorLayouts; @@ -79,14 +86,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 1, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCW}}, - .weight = {.config = {.layout = TensorLayout::GKCX}}, - .output = {.config = {.layout = TensorLayout::NGKW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 1, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKCX}}, + .output = {.config = {.layout = NGKW}}}; using TensorLayouts = ConvTensorLayouts; @@ -98,14 +105,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - 
.accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NGKHW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = FP16, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}}}; using TensorLayouts = ConvTensorLayouts; @@ -117,14 +124,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; using TensorLayouts = ConvTensorLayouts; @@ -136,14 +143,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + 
.data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}}}; using TensorLayouts = ConvTensorLayouts; @@ -155,14 +162,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 2, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKCYX}}, - .output = {.config = {.layout = TensorLayout::NGKHW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 2, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKCYX}}, + .output = {.config = {.layout = NGKHW}}}; using TensorLayouts = ConvTensorLayouts; @@ -174,14 +181,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 3, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCDHW}}, - .weight = {.config = {.layout = TensorLayout::GKCZYX}}, - .output = {.config = {.layout = TensorLayout::NGKDHW}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKCZYX}}, + .output = {.config = {.layout = NGKDHW}}}; using TensorLayouts = ConvTensorLayouts; @@ -193,14 +200,14 @@ TEST(ConvTensorLayout, 
AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 3, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NDHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::NDHWGK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = NDHWGC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NDHWGK}}}; using TensorLayouts = ConvTensorLayouts; @@ -212,14 +219,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) { - static constexpr auto sig = - ConvSignature<>{.spatial_dim = 3, - .direction = FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNDHWC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::GNDHWK}}}; + using namespace enums; + static constexpr auto sig = ConvSignature<>{.spatial_dim = 3, + .direction = FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = GNDHWC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = GNDHWK}}}; using TensorLayouts = ConvTensorLayouts; @@ -261,8 +268,10 @@ struct MockAuxiliaryTensorConfig TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}}; + MockAuxiliaryTensorConfig{.layout = G_K_strided}}; using AuxLayouts = 
AuxiliaryTensorLayouts; @@ -273,6 +282,8 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) { + using namespace enums; + static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; @@ -285,8 +296,10 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}}; + MockAuxiliaryTensorConfig{.layout = G_C_strided}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -297,9 +310,11 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) { + using namespace enums; + static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, - MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; + MockAuxiliaryTensorConfig{.layout = GC}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -311,10 +326,12 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, - MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}, - MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}}; + MockAuxiliaryTensorConfig{.layout = G_K_strided}, + MockAuxiliaryTensorConfig{.layout = GC}, + MockAuxiliaryTensorConfig{.layout = G_C_strided}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -327,8 +344,10 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout 
= TensorLayout::G_K_strided}}; + MockAuxiliaryTensorConfig{.layout = G_K_strided}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -339,8 +358,10 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution) TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) { + using namespace enums; + static constexpr std::array aux_configs = { - MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; + MockAuxiliaryTensorConfig{.layout = GC}}; using AuxLayouts = AuxiliaryTensorLayouts; @@ -351,7 +372,8 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -359,9 +381,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) .direction = FORWARD, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NGCHW}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NGKHW}, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; @@ -377,7 +399,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -385,9 +408,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) .direction = FORWARD, .data_type = DataType::BF16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = 
{.config = {.layout = TensorLayout::NHWGK}, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; @@ -403,8 +426,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = + TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -412,9 +436,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) .direction = FORWARD, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}, + .input = {.config = {.layout = GNHWC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = GNHWK}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALEADD_SCALEADD_RELU}}}; @@ -431,7 +455,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -439,9 +464,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) .direction = FORWARD, .data_type = DataType::FP32, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NWGC}}, - .weight = {.config = {.layout = TensorLayout::GKXC}}, - .output = {.config = {.layout = TensorLayout::NWGK}, + .input = {.config = {.layout = NWGC}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NWGK}, .operation = OutputOp{.elementwise_operation = 
ElementwiseOperation::SCALE}}}; @@ -457,7 +482,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) { - using OutputOp = TensorOperation; + using namespace enums; + using OutputOp = TensorOperation; static constexpr auto sig = ConvSignature, ConvolutionTensor<>, ConvolutionTensor>{ @@ -465,9 +491,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) .direction = FORWARD, .data_type = DataType::FP16, .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NDHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKZYXC}}, - .output = {.config = {.layout = TensorLayout::NDHWGK}, + .input = {.config = {.layout = NDHWGC}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NDHWGK}, .operation = OutputOp{.elementwise_operation = ElementwiseOperation::BIAS_BNORM_CLAMP}}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp new file mode 100644 index 0000000000..e4db149a98 --- /dev/null +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -0,0 +1,346 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "../impl/conv_algorithm_types.hpp" +#include +#include + +namespace ck_tile::builder::test { + +namespace ckb = ck_tile::builder; + +// Helper function to convert arrays to Seq(...) 
format +template +std::string array_to_seq(const std::array& arr) +{ + std::ostringstream oss; + oss << "Seq("; + for(size_t i = 0; i < N; ++i) + { + if(i > 0) + oss << ","; + oss << arr[i]; + } + oss << ")"; + return oss.str(); +} + +// Base template - will cause compilation error for unsupported types +template +std::string to_string(T) +{ + static_assert(sizeof(T) == 0, "Unsupported type"); + return ""; +} + +// Template specializations for enum types + +template <> +inline std::string to_string(PipelineVersion t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template <> +inline std::string to_string(PipelineScheduler t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template <> +inline std::string to_string(ConvFwdSpecialization t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template <> +inline std::string to_string(GemmSpecialization t) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +// Template specializations for struct types + +template <> +inline std::string to_string>(MNK t) +{ + return array_to_seq(std::array{t.m, t.n, t.k}); +} + +template <> +inline std::string to_string(ThreadBlock t) +{ + std::ostringstream oss; + oss << t.block_size << "," << t.tile_size.m << "," << t.tile_size.n << "," << t.tile_size.k; + return oss.str(); +} + +template <> +inline std::string to_string(GridwiseXdlGemm t) +{ + std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << "," + << t.m_xdl_per_wave << "," << t.n_xdl_per_wave; + return oss.str(); +} + +template <> +inline std::string to_string(GridwiseWmmaGemm t) +{ + std::ostringstream oss; + oss << t.k1 << "," << t.m_per_wmma << "," << t.n_per_wmma << "," << t.m_wmma_per_wave << "," + << t.n_wmma_per_wave; + return oss.str(); +} + +template <> +inline std::string to_string(BlockGemm t) +{ + std::ostringstream oss; + oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); + return 
oss.str(); +} + +template <> +inline std::string to_string(BlockTransfer t) +{ + return array_to_seq(std::array{t.k0, t.m_n, t.k1}); +} + +template <> +inline std::string to_string(ThreadCluster t) +{ + return array_to_seq( + std::array{t.m_block, t.m_wave_per_xdl, t.n_block, t.n_wave_per_xdl}); +} + +template <> +inline std::string to_string(LdsTransfer t) +{ + std::ostringstream oss; + oss << t.src_vector_dim << "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector + << "," << (t.lds_padding ? "true" : "false") << "," + << (t.is_direct_load ? "true" : "false"); + return oss.str(); +} + +template <> +inline std::string to_string(AccessOrder t) +{ + return array_to_seq(t.order); +} + +template <> +inline std::string to_string(TransferAB t) +{ + std::ostringstream oss; + oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," + << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << "," + << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector + << "," << (t.lds_transfer.lds_padding ? 
"true" : "false"); + return oss.str(); +} + +template <> +inline std::string to_string(TransferC t) +{ + std::ostringstream oss; + oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," + << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector; + return oss.str(); +} + +template <> +inline std::string to_string(TransferABC t) +{ + std::ostringstream oss; + oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); + return oss.str(); +} + +template <> +inline std::string to_string(DlThreadConfig t) +{ + std::ostringstream oss; + oss << t.k1 << "," << t.m1_per_thread << "," << t.n1_per_thread << "," << t.k_per_thread; + return oss.str(); +} + +template <> +inline std::string to_string(DlThreadCluster t) +{ + std::ostringstream oss; + oss << array_to_seq(t.m1_xs) << "," << array_to_seq(t.n1_xs); + return oss.str(); +} + +template <> +inline std::string to_string(DlBlockTransfer t) +{ + std::ostringstream oss; + oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) + << "," << array_to_seq(t.thread_cluster_arrange_order) << "," + << array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths) + << "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << "," + << array_to_seq(t.dst_vector_tensor_lengths); + return oss.str(); +} + +template <> +inline std::string to_string(DlEpilogue t) +{ + std::ostringstream oss; + oss << array_to_seq(t.src_dst_access_order) << "," << t.src_dst_vector_dim << "," + << t.dst_scalar_per_vector; + return oss.str(); +} + +template <> +inline std::string to_string(DlBlockTransferAB t) +{ + return to_string(t.block_transfer); +} + +template <> +inline std::string to_string(DlBlockTransferC t) +{ + return to_string(t.epilogue); +} + +template <> +inline std::string to_string(DlTransferABC t) +{ + std::ostringstream oss; + oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); + 
return oss.str(); +} + +// Template specializations for factory wrapper types + +template <> +inline std::string to_string(ThreadBlock_ t) +{ + return to_string(t.thread_block); +} + +template <> +inline std::string to_string(XdlGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + +template <> +inline std::string to_string(WmmaGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + +template <> +inline std::string to_string(Transfer_ t) +{ + return to_string(t.transfer); +} + +template <> +inline std::string to_string(ConvSpecialization_ t) +{ + std::ostringstream oss; + oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization); + return oss.str(); +} + +template <> +inline std::string to_string(Prefetch_ t) +{ + std::ostringstream oss; + oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << "," + << to_string(t.loop_scheduler); + return oss.str(); +} + +template <> +inline std::string to_string(BlockGemm_ t) +{ + return to_string(t.block_gemm); +} + +template <> +inline std::string to_string(DlThreadConfig_ t) +{ + return to_string(t.thread_config); +} + +template <> +inline std::string to_string(DlThreadCluster_ t) +{ + return to_string(t.thread_cluster); +} + +template <> +inline std::string to_string(DlTransfer_ t) +{ + return to_string(t.transfer); +} + +// Template specializations for algorithm types + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + 
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t) +{ + return to_string(t.base_algorithm); +} + +} // namespace ck_tile::builder::test From ce99cab6056d1ffef5acb6f4ad7ede87a46a3cfc Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:06:20 +0100 Subject: [PATCH 07/10] Wmma support for gemm_ab_scale (#3314) * Support gemm_ab_scale: - Add tests - Integrate scaling implementation in multiple D - Generalize existing b_scale for ab_scale - Add instances - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK - Add support for all layouts supported by xdl - Fix splitk xdl * Fix copyright * Wmma support for gemm_blockscale_wp (#3315) * Support for preshuffle with ab scale - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale - add support for AScaleLayout and BScaleLayout (can be different from ALayout and BLayout, respectively) - add Run method in v1 pipeline to support preshuffle + scaling - add support for preshuffle gemms in common invoker - Add splitk support * Fix copyright header --- .../65_gemm_multiply_multiply/CMakeLists.txt | 2 + ...mm_multiply_multiply_wmma_fp8_ab_scale.cpp | 345 +++++++++++++ ...ltiply_wmma_fp8_blockscale_bpreshuffle.cpp | 357 +++++++++++++ .../blockwise_gemm_pipeline_wmmaops_base.hpp | 146 ++++-- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 468
+++++++++++++++++- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 345 ++++++++++++- .../device_gemm_multiple_d_ab_scale.hpp | 347 +++++++++++++ ..._batched_gemm_wmma_cshuffle_v3_b_scale.hpp | 11 +- ...m_multiple_d_wmma_cshuffle_v3_ab_scale.hpp | 362 ++++++++++++++ ...ltiple_d_wmma_cshuffle_v3_b_preshuffle.hpp | 308 +----------- ...mma_cshuffle_v3_blockscale_bpreshuffle.hpp | 360 ++++++++++++++ ...mm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp | 102 +++- .../device_gemm_wmma_cshuffle_v3_b_scale.hpp | 10 +- .../device_gemm_wmma_cshuffle_v3_common.hpp | 200 ++++++-- .../gridwise_ab_transfer_thread_tiles.hpp | 10 +- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 6 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 7 +- ...idwise_gemm_wmma_cshuffle_v3_ab_scale.hpp} | 393 +++++++++++---- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 47 +- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 74 ++- .../gpu/gemm_ab_scale.hpp | 394 ++++++++++++++- .../gpu/gemm_blockscale_wp.hpp | 147 ++++++ .../gpu/CMakeLists.txt | 12 +- .../gpu/gemm_ab_scale/CMakeLists.txt | 21 +- ...e_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp | 79 +++ ...n_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ ...e_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp | 80 +++ ...n_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ ...e_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 95 ++++ ...k_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ .../gpu/gemm_blockscale_wp/CMakeLists.txt | 5 +- ...p_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 77 +++ 
...k_mn_128_128_128_comp_default_instance.cpp | 38 ++ ...nk_mn_128_128_128_mem_default_instance.cpp | 38 ++ .../profiler/profile_gemm_ab_scale_impl.hpp | 6 +- .../profile_gemm_blockscale_wp_impl.hpp | 2 +- test/CMakeLists.txt | 1 + test/gemm_ab_scale/CMakeLists.txt | 9 + test/gemm_ab_scale/test_gemm_ab_scale.cpp | 236 +++++++++ .../gemm_ab_scale/test_gemm_ab_scale_util.hpp | 102 ++++ test/gemm_blockscale_wp/CMakeLists.txt | 4 +- ...p8.cpp => test_gemm_blockscale_wp_fp8.cpp} | 0 51 files changed, 5144 insertions(+), 552 deletions(-) create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp rename include/ck/tensor_operation/gpu/grid/{gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp => gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp} (58%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp create mode 100644 test/gemm_ab_scale/CMakeLists.txt create mode 100644 test/gemm_ab_scale/test_gemm_ab_scale.cpp create mode 100644 test/gemm_ab_scale/test_gemm_ab_scale_util.hpp rename test/gemm_blockscale_wp/{test_gemm_blockscale_wp_xdl_fp8.cpp => test_gemm_blockscale_wp_fp8.cpp} (100%) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index abfbe115fb..944a8f96bf 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -77,3 +77,5 @@ example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCAL add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_wmma_fp16_bpreshuffle gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_wmma_fp8_bpreshuffle gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_wmma_fp8_ab_scale gemm_multiply_multiply_wmma_fp8_ab_scale.cpp) +add_example_executable(example_gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle 
gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp new file mode 100644 index 0000000000..0fb7a70781 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp @@ -0,0 +1,345 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using B0Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = 
ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3 + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8 || argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); + exit(0); + } + + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + 
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + ck::Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + ck::Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AM, + A0Layout{})); + ck::Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + ck::Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + ck::Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + ck::Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + 
a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + std::string op_name = device_op.GetTypeString(); + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(static_cast(a0_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + static_cast(a1_device_buf.GetDeviceBuffer()), + 
static_cast(b1_device_buf.GetDeviceBuffer()), + a_element_op, + b_element_op, + cde_element_op, + KBatch); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = .0; + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0, 50, 100}); + + int pass = 0; + + if(do_verification) + { + ck::Tensor c_m_n({M, N}); + ck::Tensor a_m_k({M, K}); + ck::Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 
0 + : 1; + } + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_name << ", KBatch " << KBatch << std::endl; + + return pass; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp new file mode 100644 index 0000000000..ba95724d3f --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp @@ -0,0 +1,357 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +#include "common.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using A1Layout = Col; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr int KPack = 16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle + // 
clang-format off + , S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, + S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8 || argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); + exit(0); + } + + // Transpose the AScale tensor for better performance + ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck::HostTensorDescriptor({row, 
col}, {1_uz, stride}); + } + }; + + ck::Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + ck::Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AK, + A1Layout{})); + ck::Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + ck::Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size + ck::Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + ck::Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + ck::Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + 
a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + std::string op_name = device_op.GetTypeString(); + int NPerWmma = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerWmma); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + 
b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op, + KBatch); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = 0.0f; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_name << ", KBatch " << KBatch << std::endl; + + if(do_verification) + { + ck::Tensor c_m_n({M, N}); + ck::Tensor a_m_k({M, K}); + ck::Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } + + 
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index f24a1eb3bc..f831c0f6cf 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base } }; - template - struct BScale + typename ThreadDesc> + struct ABScale { - __device__ BScale(GridDesc b_scale_grid_desc_, - ThreadCopy b_scale_thread_copy_, - GridBuffer b_scale_grid_buf_) - : b_scale_thread_copy(b_scale_thread_copy_), - b_scale_grid_desc(b_scale_grid_desc_), - b_scale_grid_buf(b_scale_grid_buf_) {}; + __device__ ABScale(GridDesc scale_grid_desc_, + ThreadCopy scale_thread_copy_, + GridBuffer scale_grid_buf_) + : scale_thread_copy(scale_thread_copy_), + scale_grid_desc(scale_grid_desc_), + scale_grid_buf(scale_grid_buf_) {}; - static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{}); static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block; - static constexpr auto b_scale_thread_desc = BScaleThreadDesc{}; + static constexpr index_t num_slice_mn = ScaleSliceSizeMN; + static constexpr index_t num_slice_k = ScaleSliceSizeK; + static constexpr index_t reg_size_per_wmma = RegSizePerWmma; - static constexpr auto b_scale_thread_copy_step = - make_tuple(make_multi_index(NWaves * NPerWmma, 0), - make_multi_index(-NPerBlock, 0), - make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK)); + static constexpr auto scale_thread_desc = ThreadDesc{}; + + static constexpr auto 
scale_thread_copy_step = + make_tuple(make_multi_index(ScaleSliceStrideMN, 0), + make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0), + make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, + ScaleSliceSizeK)); template __device__ void GlobalLoad(bool cond) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, Number<0>{}), - b_scale_thread_bufs(Number{})); + static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) { + scale_thread_copy.Run(scale_grid_desc, + scale_grid_buf, + scale_thread_desc, + make_tuple(m0 * Number{}, Number<0>{}), + scale_thread_bufs(Number{})); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<0>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<0>{})); }); if(cond) { - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<2>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<2>{})); } else { - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<1>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<1>{})); } } - ThreadCopy b_scale_thread_copy; - GridDesc b_scale_grid_desc; - GridBuffer b_scale_grid_buf; - StaticallyIndexedArray{}> b_scale_thread_bufs; + ThreadCopy scale_thread_copy; + GridDesc scale_grid_desc; + GridBuffer scale_grid_buf; + StaticallyIndexedArray{}> scale_thread_bufs; + }; + + template + struct CScale + { + __device__ CScale() {} + + static constexpr auto reg_size_per_wmma = + ck::is_same_v && ck::is_same_v + ? 
1 + : wmma_gemm.GetRegSizePerWmma(); + static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, + Number{}, + Number{})); + using CScaleThreadDesc = decltype(c_scale_thread_desc); + static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + using ThreadStaticBuffer = decltype(make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize())); + + __device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct) + { + using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc); + using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_bufs(I0)(Number{}) = + a_scale_struct.scale_thread_bufs(I0)[Number{}] * + b_scale_struct.scale_thread_bufs(I0)[Number{}]; + }); + }); + }); + } + + __device__ void Clear() + { + static_for<0, reg_size_per_wmma, 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + } + + template + __device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf) + { + static_for<0, reg_size_per_wmma, 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple( + k_index, + (m_index * num_scale_m_block / MRepeat) % 
num_scale_m_block + + (Number{}) % + AScaleStruct::reg_size_per_wmma, + (n_index * num_scale_n_block / NRepeat) % num_scale_n_block)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_bufs(I0)[Number{}]); + }); + } + + StaticallyIndexedArray{}> c_scale_thread_bufs; + StaticBufferTupleOfVector + c_thread_buf_per_scale; }; __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 0f62aee0a8..3b12e7feb0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); // Local prefill 1 @@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, @@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}], b_thread_desc_, @@ -366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& 
b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + Base::a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + Base::b_thread_desc_.GetElementSpaceSize()); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto blockwise_gemm_func = [&]() { + // Local load + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + Base::a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + Base::b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, 
NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + }; + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + + block_sync_lds(); + blockwise_gemm_func(); + + block_sync_lds(); + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 
1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + blockwise_gemm_func(); + } + } + protected: // A[MRepeat, I1, I1, KPack] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + struct KLoopParams + { + static constexpr auto KRepeatNoScale = 1; + static constexpr auto NumScaleKBlock = + Number{}; + static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock; + }; + + template <> + struct KLoopParams + { + static constexpr index_t KRepeatNoScale = KRepeatPerCluster; + static constexpr index_t NumScaleKBlock = 1; + static constexpr index_t KRepeatPerNumScaleKBlock = 1; + }; + template + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); // Local prefill 1 @@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}], b_thread_desc_, @@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc&, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer&, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + 
BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + auto gemm_core_func = [&](auto reg_buf) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[reg_buf] + [Number{}, + I0, + I0, + n0, + I0, + k_index, + Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template 
UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + }; + + auto a_local_prefetch_func = [&]() { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + }); + }; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + __builtin_amdgcn_sched_barrier(0); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + a_local_prefetch_func(); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_struct.template 
GlobalLoad<0>( + (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>( + (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0); + + gemm_core_func(wmma_reg_buf); + + block_sync_lds(); + + // loop prefetch copy + a_local_prefetch_func(); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + + gemm_core_func(I0); + + block_sync_lds(); + + // tail Local Prefetch A1 + a_local_prefetch_func(); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + __builtin_amdgcn_sched_barrier(0); + + gemm_core_func(I1); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + gemm_core_func(I0); + } + } + protected: static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 08c765dd0a..b8d451363e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3; using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; using Base::A_K1; using 
Base::A_KRow; @@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}], b_thread_desc_, @@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop_per_scale == 1); // Local prefill 1 @@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template 
GlobalLoad<0>(num_loop_per_scale == 1); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2, perform when at least 2 loops exist. + if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + } + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + + auto local_load_func = [&]() { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); + }); + }); + }; + + local_load_func(); + + __builtin_amdgcn_sched_barrier(0); + + // Main body, perform when at least 3 loops exist. 
+ if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + 
block_sync_lds(); + + local_load_func(); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 2)); + } + + // Pre-tail, perform when at least 2 loops exist. + if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // No RunRead or MoveSrcSliceWindow here, already finished them all! + a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + + 
c_scale_struct.Load(a_scale_struct, b_scale_struct); + block_sync_lds(); + + local_load_func(); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } + + // Tail, always perform. + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + protected: using Base::a_thread_copy_; using Base::a_thread_desc_; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp index 52a915de52..23b5178e3d 100644 --- 
a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp @@ -105,6 +105,353 @@ struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator virtual int GetPreShuffleParameters() = 0; }; +template +struct DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK in contexts where +/// DeviceGemmMultipleD_BlockScale_BPreshuffle is expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleD_BlockScale_BPreshuffle and +/// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK is +/// that DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK::MakeArgumentPointer requires +/// an additional parameter KBatch which is explicitly passed as 1 by this wrapper. 
+template +struct DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper + : public DeviceGemmMultipleD_BlockScale_BPreshuffle +{ + using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + p_a_scale, + p_b_scale, + a_element_op, + b_element_op, + cde_element_op, + 1); // KBatch + } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + int GetPreShuffleParameters() override { return p_op_->GetPreShuffleParameters(); } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ +}; + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleD_ABScaleSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0; +}; + +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleD_ABScaleSplitK in contexts where DeviceGemmMultipleD_ABScale is +/// expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleD_ABScale and +/// DeviceGemmMultipleD_ABScaleSplitK is that +/// DeviceGemmMultipleD_ABScaleSplitK::MakeArgumentPointer requires an additional parameter +/// KBatch which is explicitly passed as 1 by this wrapper. 
+template +struct DeviceGemmMultipleD_ABScaleSplitKWrapper + : public DeviceGemmMultipleD_ABScale +{ + + using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleD_ABScaleSplitKWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + p_a_scale, + p_b_scale, + a_element_op, + b_element_op, + cde_element_op, + 1); // KBatch + } + + void SetKBatch(BaseArgument* arg, int KBatch) const override { p_op_->SetKBatch(arg, KBatch); } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index 7752b334ed..ee1ddc494d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" @@ -93,7 +93,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) p_bs_grid_shift, karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset, p_shared, karg, karg.a_element_op, @@ -315,12 +316,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale }; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< ALayout, BLayout, Tuple<>, // DsLayout CLayout, Tuple, + void, // AScaleType Tuple, BScaleDataType, AccDataType, @@ -332,6 +334,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale CElementwiseOperation, GemmSpec, BlockSize, + 0, // ScaleBlockM ScaleBlockN, ScaleBlockK, MPerBlock, @@ -405,7 +408,9 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale std::array{StrideB_}, std::array{}, // StrideDs_ StrideC_, + 0, // StrideScaleA StrideScaleB_, + nullptr, p_b_scale_grid_, k_batch_, a_element_op_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp new file mode 100644 index 0000000000..81a5d35e7c --- 
/dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp @@ -0,0 +1,362 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3 + : public DeviceGemmMultipleD_ABScaleSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< + ALayout, + BLayout, + DsLayout, + CLayout, + Tuple, + AScaleDataType, + Tuple, + BScaleDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + 
BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + DsDataType, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + void SetKBatch(BaseArgument* base_arg, int KBatch) const override + { + auto& arg = *dynamic_cast(base_arg); + arg.KBatch = KBatch; + arg.KRead = GridwiseGemm::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm::CalculateBK0Padded(arg.K, KBatch); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + std::array p_ds, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const BScaleDataType* 
p_a_scale, + const BScaleDataType* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation cde_element_op, + index_t KBatch = 1) + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return Argument{std::array{p_a}, + std::array{p_b}, + p_ds, + p_c, + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + p_a_scale, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch = 1) override + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return std::make_unique(std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_ABScale_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); 
- __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); - const index_t k_id = blockIdx.z * num_k_per_block; - - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run( - p_shared, splitk_batch_offset, karg, epilogue_args, k_id); - -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - -} // namespace ck - namespace ck { namespace tensor_operation { namespace device { @@ -202,270 +156,14 @@ struct DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, - ComputeTypeB>; + ComputeTypeB, + true>; // IsBPreshuffle // Invoker - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - std::array size_as_buffers; - size_as_buffers[Number<0>{}] = - a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize; - - std::array size_bs_buffers; - size_bs_buffers[Number<0>{}] = - b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize; - - const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( - arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); - - std::array size_ds_buffers; - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - size_ds_buffers[i] = - ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); - }); - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - DsDataType> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - size_ds_buffers); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.M * arg_.N * sizeof(EDataType), - stream_config.stream_id_)); - 
}; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_); - } - else - { - if(arg.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, - 0, - arg.M * arg.N * sizeof(EDataType), - stream_config.stream_id_)); - - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; - } - else - { - return 1; - } - }(); - - // ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is - // currently implemented in such a way that all SrcScalarPerVectors must be the same, so - // if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the - // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot - // be odd. 
- constexpr bool AtomicsImplementationExists = - !(std::is_same_v || std::is_same_v || - std::is_same_v) || - (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - if constexpr(AtomicsImplementationExists) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - if constexpr(AtomicsImplementationExists) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - 
{ - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; + using Invoker = typename DeviceGemmCommon::Invoker; static bool IsSupportedArgument(const Argument& arg) { - if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) - { - return false; - } return DeviceGemmCommon::IsSupportedArgument(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp new file mode 100644 index 0000000000..1b1a1fcc6c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp @@ -0,0 +1,360 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle + : public DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using AScaleLayout = tensor_layout::gemm::ColumnMajor; + using BScaleLayout = BLayout; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< + ALayout, + BLayout, + DsLayout, + CLayout, + Tuple, + AScaleDataType, + Tuple, + BScaleDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + 
BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + true, + AScaleLayout, + BScaleLayout>; + + using Argument = typename GridwiseGemm::Argument; + int GetPreShuffleParameters() override { return NPerWmma; } + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + DsDataType, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true>; // IsBPreshuffle + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation cde_element_op, + index_t KBatch) + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return Argument{std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_e), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) override + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return std::make_unique(std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_e), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else { const auto kernel = kernel_gemm_xdl_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + 
kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } } } @@ -315,6 +350,20 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 { auto& arg = *dynamic_cast(base_arg); arg.KBatch = KBatch; + if(get_warp_size() == 64) + { + arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch); + } + else + { + arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch); + } } static constexpr bool IsValidCompilationParameter() @@ -325,6 +374,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + if(!ck::is_xdl_wmma_supported()) { return false; @@ -385,6 +441,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + return Argument{static_cast(p_a), static_cast(p_b), p_ds, @@ -396,6 +460,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 StrideB, StrideDs, StrideC, + StrideScaleA, + StrideScaleB, static_cast(p_a_scale), static_cast(p_b_scale), 1, @@ -425,6 +491,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override { + index_t StrideScaleA = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + return std::make_unique(static_cast(p_a), static_cast(p_b), p_ds, @@ -436,6 +510,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 StrideB, StrideDs, StrideC, + StrideScaleA, + StrideScaleB, static_cast(p_a_scale), static_cast(p_b_scale), 1, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp index e824fcc9dd..491f3a7dac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" @@ -86,12 +86,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, // DsLayout CLayout, Tuple, + void, // AScaleType Tuple, BScaleDataType, AccDataType, @@ -103,6 +104,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{StrideB}, std::array{}, // StrideDs_ StrideC, + 0, // StrideScaleA StrideScaleB, + nullptr, p_b_scale, KBatch, a_element_op, @@ -245,7 +249,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{StrideB}, std::array{}, // StrideDs_ StrideC, + 0, // StrideScaleA StrideScaleB, + nullptr, // p_a_scale static_cast(p_b_scale), KBatch, 
a_element_op, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 6706365fb7..e96ec58cba 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -38,7 +38,8 @@ template + typename ComputeTypeB, + bool IsBPreShuffled = false> struct DeviceGemm_Wmma_CShuffleV3_Common { @@ -189,61 +190,174 @@ struct DeviceGemm_Wmma_CShuffleV3_Common if(has_main_k_block_loop) { // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + if constexpr(IsBPreShuffled) { - if(arg.KBatch > 1) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - if constexpr(AtomicsImplementationExists) + if(arg.KBatch > 1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if constexpr(AtomicsImplementationExists) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + 
TailNumber::Even>; + Run(kernel); + } } - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); } } else { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - if constexpr(AtomicsImplementationExists) + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + const auto kernel = kernel_gemm_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else { const auto kernel = kernel_gemm_wmma_cshuffle_v3; Run(kernel); } } - else + } + } + else + { + if constexpr(IsBPreShuffled) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } 
+ } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + const auto kernel = kernel_gemm_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } } } } @@ -299,6 +413,14 @@ struct DeviceGemm_Wmma_CShuffleV3_Common return false; } + if constexpr(IsBPreShuffled) + { + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + } + return GridwiseGemm::CheckValidity(arg); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 4526eb3186..69f8f44390 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -388,11 +388,11 @@ struct ABTransferThreadTiles // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 return transform_tensor_descriptor( BlockDesc{}, - make_tuple( - make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), + make_tuple(make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index f58f67dc6b..121ca258be 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 c_thread_buf.Clear(); // Empty BScale struct for the blockwise pipeline. - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; + using ABScale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = ABScale{}; + auto b_scale_struct = ABScale{}; /*******************************************************************************/ // @@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 b0_block_buf, b0_block_slice_copy_step, acc0_thread_buf, + a_scale_struct, b_scale_struct, KBlockMainLoop, 1); // num_k_block_per_scale diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index e55ac807c5..fea0102337 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); // BScale struct (Empty) - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; + using Scale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = Scale{}; + auto b_scale_struct = Scale{}; const index_t num_k_block_per_scale = GetKBlockPerScale(); @@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 decltype(bs_grid_desc_bk0_n_bk1), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(a_scale_struct), decltype(b_scale_struct), decltype(epilogue_args), HasMainKBlockLoop, @@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 block_m_id, block_n_id, num_k_block_per_scale, + a_scale_struct, 
b_scale_struct, epilogue_args, k_id); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp similarity index 58% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp index 8684731c96..ac5b7dd0c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp @@ -23,6 +23,7 @@ template -struct GridwiseGemm_wmma_cshuffle_v3_b_scale + BlockGemmPipelineScheduler BlkGemmPipeSched, + BlockGemmPipelineVersion BlkGemmPipelineVer, + typename ComputeTypeA, + typename ComputeTypeB, + bool PermuteA, + bool PermuteB, + bool IsBPreShuffled = false, + typename AScaleLayout = ALayout, + typename BScaleLayout = BLayout> +struct GridwiseGemm_wmma_cshuffle_v3_ab_scale : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, @@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale ComputeTypeB, PermuteA, PermuteB, - false, + IsBPreShuffled, true> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< @@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale ComputeTypeB, PermuteA, PermuteB, - false, + IsBPreShuffled, true>; using Base::I0; @@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs_, std::array StrideDs_, index_t StrideE_, + index_t StrideScaleA_, index_t StrideScaleB_, index_t KBatch_) : M{M_}, @@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale StrideBs{StrideBs_}, StrideDs{StrideDs_}, StrideE{StrideE_}, + StrideScaleA{StrideScaleA_}, StrideScaleB{StrideScaleB_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, @@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, 
MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)} + NBlock{CalculateNBlock(N_)}, + Kt{K_} { } @@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale }); std::cout << " }, "; } - std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", " - << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead - << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 - << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" - << std::endl; + std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", " + << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs; std::array StrideDs; index_t StrideE; + index_t StrideScaleA; index_t StrideScaleB; index_t KBatch; index_t MPadded; @@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale index_t BK0; index_t MBlock; index_t NBlock; + index_t Kt; }; // Argument @@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs_, std::array StrideDs_, index_t StrideE_, + index_t StrideScaleA_, index_t StrideScaleB_, + const AScaleType* p_a_scale_grid_, const BScaleType* p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, @@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale StrideBs_, StrideDs_, StrideE_, + StrideScaleA_, StrideScaleB_, k_batch_}, p_as_grid{}, p_bs_grid{}, p_ds_grid{}, p_e_grid{p_e_grid_}, + p_a_scale_grid{p_a_scale_grid_}, p_b_scale_grid{p_b_scale_grid_}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, @@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale DsGridPointer p_ds_grid; EDataType* p_e_grid; + const 
AScaleType* p_a_scale_grid; const BScaleType* p_b_scale_grid; const AElementwiseOperation a_element_op; const BElementwiseOperation b_element_op; @@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; }); } - if constexpr(is_same_v) + if constexpr(IsBPreShuffled) { - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; }); } - else if constexpr(is_same_v) + else { - if constexpr(!PermuteB) + if constexpr(is_same_v) { - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; + }); } - else + else if constexpr(is_same_v) { - const int k0_offset = karg.KRead * karg.N; - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); + if constexpr(!PermuteB) + { + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); + } + else + { + const int k0_offset = karg.KRead * karg.N; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); + } } } - // Calculate B scale offset - if constexpr(is_same_v) + // Calculate A scale offset + if constexpr(is_same_v) { - scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB; + scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK); } - else if constexpr(is_same_v) + else if constexpr(is_same_v) { - scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK); + scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB; + } + else if 
constexpr(is_same_v) + { + scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK); } if(k_id < karg.KBatch - 1) @@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array a_k_split_offset; std::array b_k_split_offset; - index_t scale_k_split_offset; // New member for scale matrix offset + index_t scale_a_k_split_offset; // A scale matrix offset + index_t scale_b_k_split_offset; // B scale matrix offset index_t c_reduce_offset; }; using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; // return block_id to C matrix tile idx (m0, n0) mapping - // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; - // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template - __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, - const BScaleType* p_b_scale_grid, - index_t block_n_id) + __device__ static constexpr auto + MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA) { - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto BM = math::integer_divide_ceil(M, ScaleBlockM); + const auto BK = math::integer_divide_ceil(K, ScaleBlockK); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA)); + } + } - static constexpr auto wmma = - WmmaSelector{}; - static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma; + template + __device__ static auto + MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id) + { + if constexpr(ck::is_same_v) + { + using AScale = typename BlockwiseGemmPipe::Empty; + return AScale{}; + } + else + { +#if defined(__gfx11__) + // TODO: remove this restriction + static_assert(ScaleBlockM >= MPerWmma, + 
"ScaleBlockM must be greater equal than MPerWmma"); +#endif + static_assert( + ScaleBlockK >= + WmmaSelector:: + selected_wmma.k_per_wmma, + "ScaleBlockK must be greater equal than KPerWmma"); - static constexpr auto ScaleSliceSizeN = NRepeat; - static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + const auto a_scale_grid_desc_am_ak = + MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA); - constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + constexpr auto wmma = + WmmaSelector{}; + constexpr auto RegSizePerWmmaFull = + wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number; + constexpr auto RegSizePerWmma = + math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM); - auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma + - (get_thread_local_1d_id() / 32) % NWaves * NPerWmma; - auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread; + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0, 1>, - 1, - ScaleSliceSizeK, - 1, - false>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, - b_thread_offset_k / ScaleBlockK)); + constexpr auto ScaleSliceSizeM = + ScaleBlockM < MPerWmma ? 
MRepeat * RegSizePerWmma + : math::integer_divide_ceil(MPerBlock, ScaleBlockM); + constexpr auto ScaleSliceStrideM = + math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM); + constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); - auto b_scale_thread_buf = make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize()); + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); - using BScale = - typename BlockwiseGemmPipe::template BScale; + auto a_thread_offset_m = + ((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) / + math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) + + (get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM; - return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + constexpr index_t VectorDim = + is_same::value ? 0 : 1; + constexpr index_t VectorSize = + is_same::value ? RegSizePerWmma + : ScaleSliceSizeK; + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + VectorDim, + VectorSize, + 1, + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0)); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + using AScale = + typename BlockwiseGemmPipe::template ABScale; + + return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf}; + } + } + + __device__ static constexpr auto + MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB) + { + const auto BN = math::integer_divide_ceil(N, ScaleBlockN); + const auto BK = math::integer_divide_ceil(K, ScaleBlockK); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB)); + } + } + + template + 
__device__ static auto + MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id) + { + if constexpr(ck::is_same_v) + { + using BScale = typename BlockwiseGemmPipe::Empty; + return BScale{}; + } + else + { + static_assert( + ScaleBlockK >= + WmmaSelector:: + selected_wmma.k_per_wmma, + "ScaleBlockK must be greater equal than KPerWmma"); + + const auto b_scale_grid_desc_bn_ak = + MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto ScaleSliceSizeN = + ScaleBlockN < NPerWmma ? NRepeat + : math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr auto ScaleSliceStrideN = + math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN); + constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma + + (get_thread_local_1d_id() / 32) % NWaves * NPerWmma) / + ScaleBlockN; + + constexpr index_t VectorDim = + is_same::value ? 0 : 1; + constexpr index_t VectorSize = + is_same::value ? 
1 : ScaleSliceSizeK; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + VectorDim, + VectorSize, + 1, + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0)); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + using BScale = + typename BlockwiseGemmPipe::template ABScale; + + return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + } } __device__ static index_t GetKBlockPerScale() { - return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + if constexpr(ck::is_same_v && ck::is_same_v) + { + return 0; + } + else + { + return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + } } template ( @@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); - // B Scale grid - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), - math::integer_divide_ceil(problem.K, ScaleBlockK)), - make_tuple(problem.StrideScaleB, 1)); - // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + // AScale struct + auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id); + // BScale struct - auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id); + auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id); const index_t num_k_block_per_scale = GetKBlockPerScale(); @@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale decltype(bs_grid_desc_bk0_n_bk1), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), 
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(a_scale_struct), decltype(b_scale_struct), decltype(epilogue_args), HasMainKBlockLoop, @@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale block_m_id, block_n_id, num_k_block_per_scale, + a_scale_struct, b_scale_struct, - epilogue_args); + epilogue_args, + k_id); } // NOTE: Wrapper function to have __global__ function in common @@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg, - EpilogueArgument& epilogue_args) + EpilogueArgument& epilogue_args, + const index_t k_id = 0) { // shift A matrices pointer for splitk AsGridPointer p_as_grid_splitk; @@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale splitk_batch_offset.b_k_split_offset[i]; }); + const AScaleType* p_a_scale_grid_ptr; + if constexpr(ck::is_same_v) + { + p_a_scale_grid_ptr = karg.p_a_scale_grid; + } + else + { + p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset; + } + + const BScaleType* p_b_scale_grid_ptr; + if constexpr(ck::is_same_v) + { + p_b_scale_grid_ptr = karg.p_b_scale_grid; + } + else + { + p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset; + } + Run( p_as_grid_splitk, p_bs_grid_splitk, karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_a_scale_grid_ptr, + p_b_scale_grid_ptr, p_shared, karg, karg.a_element_op, karg.b_element_op, karg.cde_element_op, - epilogue_args); + epilogue_args, + k_id); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 04d1d98448..81aa1ac986 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif } +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); + const index_t k_id = blockIdx.z * num_k_per_block; + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args, k_id); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + template ( - karg.p_a_grid, - karg.p_b_grid, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset, p_shared, karg, karg.a_element_op, @@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } } - __host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K) + __host__ __device__ static constexpr auto + MakeAScaleGridDesciptor_M_K(index_t M, 
index_t K, index_t StrideScaleA) { const auto BM = math::integer_divide_ceil(M, ScaleBlockM); const auto BK = math::integer_divide_ceil(K, ScaleBlockK); if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1)); + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM)); + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA)); } } - __host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K) + __host__ __device__ static constexpr auto + MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB) { const auto BN = math::integer_divide_ceil(N, ScaleBlockN); const auto BK = math::integer_divide_ceil(K, ScaleBlockK); if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1)); + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN)); + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB)); } } @@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB_, std::array StrideDs_, index_t StrideC_, + index_t StrideScaleA_, + index_t StrideScaleB_, index_t KBatch_) : M{M_}, N{N_}, @@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 StrideB{StrideB_}, StrideDs{StrideDs_}, StrideC{StrideC_}, + StrideScaleA{StrideScaleA_}, + StrideScaleB{StrideScaleB_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, @@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB; std::array StrideDs; index_t StrideC; - + index_t StrideScaleA; + index_t StrideScaleB; index_t KBatch; index_t MPadded; index_t 
NPadded; @@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB_, std::array StrideDs_, index_t StrideC_, + index_t StrideScaleA_, + index_t StrideScaleB_, const AScaleType* p_a_scale_grid_, const BScaleType* p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + : Problem{M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideC_, + StrideScaleA_, + StrideScaleB_, + k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_ds_grid{}, @@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 b_k_split_offset = blockIdx.z * karg.KRead; } + // Calculate A scale offset + if constexpr(is_same_v) + { + scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + else if constexpr(is_same_v) + { + scale_a_k_split_offset = + blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_b_k_split_offset = + blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + if(blockIdx.z < static_cast(karg.KBatch - 1)) { karg.K = karg.KRead; @@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t a_k_split_offset; index_t b_k_split_offset; + index_t scale_a_k_split_offset; // A scale matrix offset + index_t scale_b_k_split_offset; // B scale matrix offset }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K); - const auto 
b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K); + const auto a_scale_grid_desc_am_ak = + MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA); + const auto b_scale_grid_desc_bn_ak = + MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp index faf10c2cce..d4ddbafeee 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp @@ -16,7 +16,231 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_WMMA_FP8 +// Row, Col +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& 
instances); + +// Row, Row +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +// Col, Row +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); +#endif +#ifdef CK_USE_XDL // Row, Col void 
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector>>& instances); #endif +#endif template -struct DeviceOperationInstanceFactory, - CLayout, - A0DataType, - A1DataType, - B0DataType, - B1DataType, - Tuple<>, - CDataType, - 1, - 128, - 128, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough>> +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_ABScaleSplitK, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>> +{ + using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleD_ABScaleSplitK at the moment +#endif +#ifdef CK_USE_WMMA_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + 
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_ABScale, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>> { using DeviceOp = DeviceGemmMultipleD_ABScale; + PassThrough, + PassThrough, + PassThrough>; static auto GetInstances() { std::vector> op_ptrs; #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_XDL if constexpr(is_same_v && is_same_v && is_same_v) { @@ -328,6 +655,33 @@ struct DeviceOperationInstanceFactory, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA_FP8 #endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp index a8d9545194..d660c18fd0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp 
+++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp @@ -17,6 +17,47 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8)) +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); +#endif // CK_USE_WMMA && CK_USE_WMMA_FP8 + +#ifdef CK_USE_XDL void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector>>& instances); #endif +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>> +{ + using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK at the moment +#endif // CK_USE_XDL + +#if(defined(CK_USE_WMMA) && 
defined(CK_USE_WMMA_FP8)) +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + op_ptrs); + + add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + op_ptrs); + } + } +#endif +#endif // CK_USE_WMMA && CK_USE_WMMA_FP8 + + return op_ptrs; + } +}; template > op_ptrs; +#ifdef CK_USE_XDL #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) if constexpr(is_same_v && is_same_v && is_same_v) @@ -162,6 +280,35 @@ struct DeviceOperationInstanceFactory< } } #endif +#endif + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK instances + using Wrapper = DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA + return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index ef037526ca..575e14d5bb 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -103,6 +103,16 @@ function(add_instance_library INSTANCE_NAME) message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() + # Do not build gemm_ab_scale_f8 for any targets except gfx94, gfx95 and gfx12 + if(NOT (INST_TARGETS MATCHES "gfx942" OR 
INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_ab_scale") AND (source_name MATCHES "_f8_f8_")) + message(DEBUG "removing gemm_ab_scale_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + # Do not build gemm_blockscale_wp_f8 for any targets except gfx94, gfx95 and gfx12 + if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_blockscale_wp") AND (source_name MATCHES "_f8_f8_")) + message(DEBUG "removing gemm_blockscale_wp_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() # Only build tf32 instances for gfx942 & gfx950 if(source_name MATCHES "_tf32_") if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32)) @@ -300,7 +310,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found gemm_multiply_multiply instances, but gfx94/gfx95/gfx11/gfx12 not on the target list. Skipping. ${cmake_instance}") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle|gemm_blockscale" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) + if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle|gemm_blockscale|gemm_ab_scale" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) message(DEBUG "Found gemm_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") set(add_inst 0) endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt index a315db8bdd..0512b01175 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt @@ -1,21 +1,38 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_AB_SCALE_INSTANCES) list(APPEND GEMM_AB_SCALE_INSTANCES # Row, Col + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + # Row, Row + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp 
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + # Col, Row + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -27,11 +44,13 @@ set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_s set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + # Row, Row 
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + # Col, Row set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp new file mode 100644 index 0000000000..a4058ca1c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp @@ -0,0 +1,79 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, 
S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S< 8, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 
128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp 
new file mode 100644 index 0000000000..ad0667dd10 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..dbdfd41e32 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..1380df5291 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..90dbb9c9d5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp new file mode 100644 index 0000000000..c45adb91c6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp @@ -0,0 +1,80 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 
32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Memory friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 
256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..766279520a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..b837c35810 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..2fc87ba6ad --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..2188a64c98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp new file mode 100644 index 0000000000..cc1be58946 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -0,0 +1,95 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 64, 16, 16, 16, 16, 4, 2, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 64, 16, 16, 16, 16, 2, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 64, 16, 16, 16, 16, 4, 2, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 64, 16, 16, 16, 16, 2, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + 
//#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Memory friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, 
BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 256, 8, 16, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 1, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, 
F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 2, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, 
F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp new file 
mode 100644 index 0000000000..3c140ef980 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..d68b755506 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..5822fd0b2a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..f4661891d1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt index b37a22d895..dd7596447e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") set(GEMM_BLOCKSCALE_WP_INSTANCES) @@ -10,6 +10,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + + device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp ) check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp new file mode 100644 index 0000000000..023d1ac2b8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //######################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //######################################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S <1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S <1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S <1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +template +using device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //######################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //######################################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 
1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + 
DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, 
Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..59fe63421a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp new file mode 100644 index 0000000000..2b5670ead3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp index 5396a52e21..f3055575ea 100644 --- a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp @@ -109,8 +109,8 @@ bool profile_gemm_ab_scale_impl(int do_verification, case 1: a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); @@ -302,7 +302,7 @@ bool profile_gemm_ab_scale_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << ", KBatch " << KBatch << std::endl; if(tflops > best_tflops) { diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 49fef5a0fc..8642cc59e6 
100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -29,7 +29,7 @@ void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K { int KPack = 16; int NLane = NXdl; - int KLane = 64 / NLane; + int KLane = ck::get_warp_size() / NLane; int K0 = K / (KLane * KPack); // K -> K0 KLane KPack diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b7db14945d..802f29024c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -261,6 +261,7 @@ add_subdirectory(gemm_multiply_multiply_wp) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) add_subdirectory(gemm_universal_preshuffle) +add_subdirectory(gemm_ab_scale) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) diff --git a/test/gemm_ab_scale/CMakeLists.txt b/test/gemm_ab_scale/CMakeLists.txt new file mode 100644 index 0000000000..21203aafaa --- /dev/null +++ b/test/gemm_ab_scale/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_ab_scale test_gemm_ab_scale.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_ab_scale PRIVATE utility device_gemm_ab_scale_instance) + endif() +endif() diff --git a/test/gemm_ab_scale/test_gemm_ab_scale.cpp b/test/gemm_ab_scale/test_gemm_ab_scale.cpp new file mode 100644 index 0000000000..01c3e2ffdb --- /dev/null +++ b/test/gemm_ab_scale/test_gemm_ab_scale.cpp @@ -0,0 +1,236 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_ab_scale_util.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; +using F8 = ck::f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmABScale_MK_NK : public ck::test::TestGemmABScale< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmABScale_MK_KN : public ck::test::TestGemmABScale< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmABScale_KM_KN : public ck::test::TestGemmABScale< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ADataType, BDataType, ComputeDataType, EDataType + std::tuple< F8, F32, F8, F32, F8, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmABScale_MK_NK, KernelTypes); +TYPED_TEST_SUITE(TestGemmABScale_MK_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmABScale_KM_KN, KernelTypes); + +// Row Col +TYPED_TEST(TestGemmABScale_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, SmallMPadK) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 704; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + 
constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideE = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideE); +} + +// Row Row +TYPED_TEST(TestGemmABScale_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, SmallMPadK) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 704; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideE = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideE); +} + +// Col Row +TYPED_TEST(TestGemmABScale_KM_KN, SmallM) +{ + std::vector Ms{16, 32}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, SmallMPadK) +{ + std::vector Ms{16, 32}; + constexpr int 
N = 512; + constexpr int K = 704; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, MidLargeM) +{ + std::vector Ms{128, 256}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideE = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideE); + } +} diff --git a/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp b/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp new file mode 100644 index 0000000000..b54e5ce2e5 --- /dev/null +++ b/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_ab_scale_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmABScale : public testing::Test +{ + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using ELayout = std::tuple_element_t<2, Tuple>; + using A0DataType = std::tuple_element_t<3, Tuple>; + using A1DataType = std::tuple_element_t<4, Tuple>; + using B0DataType = std::tuple_element_t<5, Tuple>; + using B1DataType = std::tuple_element_t<6, Tuple>; + using ComputeDataType = std::tuple_element_t<7, Tuple>; + using EDataType = std::tuple_element_t<8, Tuple>; + + public: + static constexpr ck::index_t ScaleBlockM = 1; + static constexpr ck::index_t ScaleBlockN = 128; + static constexpr ck::index_t ScaleBlockK = 128; + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideE) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideE, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideE, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_gemm_ab_scale_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideE, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // 
namespace ck diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt index a095968035..a0750255d1 100644 --- a/test/gemm_blockscale_wp/CMakeLists.txt +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -2,8 +2,8 @@ # SPDX-License-Identifier: MIT if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") - add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) + add_gtest_executable(test_gemm_blockscale_wp_fp8 test_gemm_blockscale_wp_fp8.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) + target_link_libraries(test_gemm_blockscale_wp_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) endif() endif() diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_fp8.cpp similarity index 100% rename from test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp rename to test/gemm_blockscale_wp/test_gemm_blockscale_wp_fp8.cpp From 715671e419cbbebe72109ceeeed9d582cca34d02 Mon Sep 17 00:00:00 2001 From: eliotwang <46883838+eliotwang@users.noreply.github.com> Date: Thu, 11 Dec 2025 23:20:29 +0800 Subject: [PATCH 08/10] Bf16*fp4 gemm (#2801) * support bf16*mxfp4 gemm * rebase bf16*fp4 example to develop branch * Clean up commented debug code in GEMM kernel * rename example folder * support bf16*mxfp4 gemm * rebase bf16*fp4 example to develop branch * Clean up commented debug code in GEMM kernel * rename example folder * rebase to new develop * fix clang format * update code according to reviewer's comment * Update README.md * update code according to reviewer's comment * update code according to reviewer's comment * Update CMakeLists.txt * Update README.md * Update CMakeLists.txt * Delete files * Delete files * Add unit tests * Update test_gemm_quant_base.hpp * merge bf16*fp4 example to develop branch * fix clang format * fix clang format * Update CMakeLists.txt * fix ci test 
* fix clang format * resolve conflicts --------- Co-authored-by: eliotwang Co-authored-by: ShaoChunLee Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng Co-authored-by: Thomas Ning --- .../38_block_scale_gemm/CMakeLists.txt | 1 + example/ck_tile/38_block_scale_gemm/README.md | 4 +- .../gemm_bquant_quantgrouped_bf16mxfp4.cpp | 41 ++ .../38_block_scale_gemm/gemm_quant.cpp | 5 +- .../38_block_scale_gemm/gemm_utils.hpp | 6 +- .../run_gemm_quant_example.inc | 121 ++-- .../core/arch/amd_buffer_addressing.hpp | 7 +- include/ck_tile/host/check_err.hpp | 32 +- .../ck_tile/host/reference/reference_gemm.hpp | 57 ++ include/ck_tile/ops/common/utils.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 14 +- .../block/block_universal_gemm_as_bs_cr.hpp | 6 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 21 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 44 +- include/ck_tile/ops/gemm_quant.hpp | 3 + .../gemm_quant/kernel/gemm_quant_kernel.hpp | 49 +- .../gemm_mxfp4_pipeline_ag_bg_cr_base.hpp | 59 ++ .../gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp | 140 ++++ .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 665 ++++++++++++++++++ test/ck_tile/gemm_block_scale/CMakeLists.txt | 0 .../gemm_block_scale/test_gemm_quant_base.hpp | 6 +- .../test_gemm_quant_bquant.cpp | 6 + .../test_gemm_quant_fixtures.hpp | 109 ++- 23 files changed, 1260 insertions(+), 137 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp mode change 100644 => 100755 test/ck_tile/gemm_block_scale/CMakeLists.txt diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 
d6b63dc47b..40f06ec97a 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -16,6 +16,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp + gemm_bquant_quantgrouped_bf16mxfp4.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 3a30c2bad3..eb36ae5800 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -23,7 +23,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming - **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM. - **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading - **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension. -- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix). +- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix), uint8 (split into two fp4 in the pipeline (for B Matrix)). - **Validation**: CPU/GPU validation and error tolerance options. ## build @@ -53,7 +53,7 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -prec Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp new file mode 100644 index 0000000000..a022ce18e1 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf16fp4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp 
b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 45d2151d5e..669bce2995 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4") + "bf8i4 or bf16fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") @@ -97,6 +97,8 @@ void bquant_quantgrouped_fp8i4_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_bf8i4_instance_factory( std::unordered_map>& lut); +void bquant_quantgrouped_bf16fp4_instance_factory( + std::unordered_map>& lut); void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_preshufflequant_instance_factory( @@ -128,6 +130,7 @@ int main(int argc, char* argv[]) bquant_quantgrouped_bf8_instance_factory(lut); bquant_quantgrouped_fp8i4_instance_factory(lut); bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_bf16fp4_instance_factory(lut); bquant_quantgrouped_preshuffleb_instance_factory(lut); bquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 2b2333b04c..aabbfff3bd 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -69,8 +69,10 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = std::conditional_t< + std::is_same_v, + ADataType, + std::conditional_t>; // Calculate 
thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 8a0dd9bc08..fa5e1f12e3 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -136,9 +136,13 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t, ck_tile::AQuantGemmPipelineAgBgCrMem>, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + std::conditional_t< + GemmConfig::PreshuffleB == true, + ck_tile::WPQuantBPipelineAgBgCrV2, + std::conditional_t< + std::is_same_v, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>; constexpr bool TiledPermuteN = (QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; @@ -147,28 +151,31 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue, + typename TypeConfig::ADataType, + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + 
GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -205,7 +212,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + std::is_same_v ? args.K / 2 + : args.K, + args.N, + args.stride_B, + is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -427,7 +438,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride( + (std::is_same_v) ? (K / 2) : K, + N, + stride_B, + is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -454,8 +469,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + (std::is_same_v) ? 
(K / 2) : K, + N, + stride_B, + is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -499,13 +517,22 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) @@ -721,13 +748,23 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - ck_tile::reference_gemm_quant(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mxfp4gemm_quant( + a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { @@ -787,16 +824,18 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) using Col = ck_tile::tensor_layout::gemm::ColumnMajor; if((QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && + QuantMode == ck_tile::QuantType::RowColQuant || + std::is_same_v) && GemmConfig::PreshuffleB) { throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); + "Preshuffling 
weight matrix is not supported for AQuant, RowColQuant or bf16_fp4_gemm"); } if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) + std::is_same_v || + std::is_same_v) { std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 8830adfdd9..9c2ce62856 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1550,9 +1550,10 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! 
not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index ac388992d1..a1be8027b2 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -52,9 +52,19 @@ template ::value, - "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + static_assert(is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_any_of::value) @@ -113,9 +123,19 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { - static_assert( - is_any_of::value, - "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + static_assert(is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 883b08fcaa..0aa296b8d9 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -246,6 +246,63 @@ CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor& a_m_k make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + AccDataType pasual = 0; + for(std::size_t k = 0; k < (K / 2); k++) + { + using ComputeType = float; + 
auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; + ComputeType v_a_0, v_a_1; + ComputeType v_b_0, v_b_1; + + v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); + v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); + + if constexpr(std::is_same_v) + { + auto b_pack = type_convert(b_element_op(b_k_n(k, n))); + auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); + + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + + v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; + v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + } + + pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; + v_acc += pasual; + } + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); + std::cout << std::endl; +} + template struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9a7876f6a5..ad1862306a 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -92,11 +92,17 @@ struct CShuffleEpilogue using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; + using ATypeToUse = std::conditional_t || + std::is_same_v, + BDataType, + ADataType>; // Used for 
weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v || + std::is_same_v, + ADataType, + BDataType>; + using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 8541ffa3a9..f6e26ad206 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -96,8 +96,10 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v, + ADataType, + BDataType>; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index f39d41a653..343e37ed66 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase using BsLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ADataType = remove_cvref_t{}, AsDataType>>; - using ALayout = remove_cvref_t{}, AsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; - using BLayout = remove_cvref_t{}, BsLayout>>; + using ADataType = remove_cvref_t{}, AsDataType>>; + using ALayout = remove_cvref_t{}, AsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; + using BDataType = + std::conditional_t, ADataType, BInDataType>; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = 
BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -270,12 +272,17 @@ struct GemmPipelineAgBgCrImplBase }(); auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); + using BLdsDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename BLdsLoadTileDistr::DstrEncode, - typename Problem::BDataType>::TransposedDstrEncode{}); + typename InputTileDistributionTraits::TransposedDstrEncode{}); + else return BLdsLoadTileDistr{}; }(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 76341af70b..a45d41189b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -303,8 +303,11 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -585,9 +588,12 @@ struct UniversalGemmBasePolicy using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BLayout = remove_cvref_t{}, BsLayout>>; + using BInDataType = remove_cvref_t{}, BsDataType>>; - using BLayout = remove_cvref_t{}, BsLayout>>; - using BDataType = remove_cvref_t{}, BsDataType>>; + using BDataType = std::conditional_t, + typename Problem::ADataType, + typename 
Problem::BDataType>; if constexpr(Problem::FixedVectorSize) { @@ -729,13 +735,17 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BDataType = remove_cvref_t; + constexpr index_t KPerBlock = std::is_same_v + ? Problem::BlockGemmShape::kK / 2 + : Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + std::is_same_v + ? 4 + : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - using BLayout = remove_cvref_t< - std::tuple_element_t{}, remove_cvref_t>>; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { @@ -841,10 +851,12 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr index_t smem_size_b = - integer_least_multiple(sizeof(typename Problem::BDataType) * - Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, - 16); + using BDataType = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + constexpr index_t smem_size_b = integer_least_multiple( + sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16); return smem_size_b; } @@ -882,8 +894,10 @@ struct UniversalGemmPipelineAgBgCrPolicy using BDataType = remove_cvref_t; using ATypeToUse = std::conditional_t, BDataType, ADataType>; - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = std::conditional_t || + std::is_same_v, + ADataType, + BDataType>; using WarpGemm = WarpGemmDispatcher( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(std::is_same_v) + return 
make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -885,10 +893,16 @@ struct QuantGemmKernel const auto& b_tensor_view = views.at(I2); if constexpr(std::is_same_v) { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + if constexpr(std::is_same_v) + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + else + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -1020,10 +1034,17 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + if constexpr(std::is_same_v) + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + else + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..58019d703e --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +template +struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +{ + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + using QuantGroupSize = remove_cvref_t; + + using BQLayout = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + + static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize"); + static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); + + static_assert(NPerBlock % QuantGroupSize::kN == 0, + "NPerBlock must be a multiple of QuantGroupSize::kN"); + static_assert(KPerBlock % QuantGroupSize::kK == 0, + "KPerBlock must be a multiple of QuantGroupSize::kK"); + + // Create DRAM tile window for BQ + template + CK_TILE_DEVICE constexpr auto + GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const + { + static_assert(std::is_same_v); + + using YPerTile = number; + using XPerTile = number; + + auto bq_copy_dram_window = + make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + bq_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBQDramTileDistribution()); + return bq_copy_dram_window; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp 
b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..6ce2ff10fa --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQLayout = remove_cvref_t; + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; + + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + // using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2 + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? 
LargestVec : VecLoadSize; + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = K0 / K1; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize!"); + + using WarpGemm = WarpGemmDispatcher; + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< + typename Problem::ADataType, + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>, + typename Problem::CDataType, + BlockWarps, + WarpGemm>; + + return BlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp new file mode 100644 index 0000000000..c113521d6b --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -0,0 +1,665 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register + +template +struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BDqDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t BQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN; + static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template 
GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeBQ() + { + return Policy::template GetVectorSizeBQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "mxfp4gemm_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', kPadM, kPadN, kPadK), + concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName()); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t 
B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t BQ_Buffer_Load_Inst_Num = + NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "BQ vector size: " << GetVectorSizeBQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public 
PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 + ? 
B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? 
sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) 
* ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/BQ Dram block window should have the same data type as appropriate " + "([A|B|BQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_bq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Bq block window has incorrect lengths for defined BqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert( + is_b_row_major + ? 
(KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + // int b_block_stride = 0; + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load, (kN, kK/2) + // B LDS tile window for store, (kN, kK) + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // B scale DRAM tile window for load + // auto b_scale_copy_dram_window = + // make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + // bq_dram_block_window_tmp.get_window_lengths(), + // bq_dram_block_window_tmp.get_window_origin(), + // Policy::template GetBQDramLoadWindow()); + auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); + + auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + using ABlockTileDistr = 
decltype(a_copy_dram_window.get_tile_distribution()); + // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_fp4_block_tile; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2); + + constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK; + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // BDataType + auto b_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + constexpr auto idx1_js = tile_distributed_index<0>{}; + constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto b_scale_uint = 
type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + block_sync_lds(); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto 
i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + + 
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + + auto b_scale_uint = + type_convert(bq_block_tile(i_j_idx_scale)) - 127; + auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = + type_convert(type_convert(b_f4_lo) * b_scale); + b_block_tile(i_j_idx_hi) = + type_convert(type_convert(b_f4_hi) * b_scale); + }); + }); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + // b_block_stride +=1; + } while(i < (num_loop - 1)); + } + // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + // Leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if 
constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } + __builtin_amdgcn_sched_barrier(0); + return c_block_tile; + } + }; + + /** + * @brief This function runs the pipeline using compile-time known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem, + index_t n = 0) const + { + ck_tile::ignore = n; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDqDataType& b) { return b; }, + bq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 39a7c66f38..fe5d2bd7e1 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -131,8 +131,10 @@ class TestCkTileGemmQuantBase : 
public ::testing::Test const ck_tile::index_t kbatch, const float max_accumulated_value) { - using ComputeType = - std::conditional_t; + using ComputeType = std::conditional_t< + std::is_same_v, + ADataType_, + std::conditional_t>; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp index ef0d41909b..ec123364cb 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp @@ -16,9 +16,12 @@ using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using BF16 = ck_tile::bf16_t; +using UInt8 = ck_tile::pk_fp4_raw_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; using GroupSize64 = ck_tile::QuantGroupShape>; +using GroupSize32 = ck_tile::QuantGroupShape>; // 2d block sizes for BQuant using GroupSize2D8N = ck_tile::QuantGroupShape>; @@ -42,6 +45,9 @@ using BQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, + std::tuple, + + std::tuple, // 2d cases with grouping also on the n axis std::tuple, diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index bf9c7a138d..4f2edb3609 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -60,6 +60,13 @@ struct GemmConfigPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Tile = 128; }; +struct GemmConfigMxFp4 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; +}; + struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool 
PreshuffleQuant = true; @@ -403,7 +410,8 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? (K / 2) : K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -414,15 +422,27 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + std::is_same_v ? K / 2 : K, + N, + stride_B, + this->is_row_major(BLayout{}))); ck_tile::HostTensor bq_bqk_bqn( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); - ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); - ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn); + } + else + { + ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); + } + // Allocate device memory ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); @@ -501,13 +521,22 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + else + ck_tile::reference_gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); // Get device result ck_tile::HostTensor c_m_n_dev_result( @@ -580,33 +609,37 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase; - using GemmPipeline = - std::conditional_t, - ck_tile::WPQuantBPipelineAgBgCrV2>; + using 
GemmPipeline = std::conditional_t< + PreshuffleB == false, + std::conditional_t, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>, + ck_tile::WPQuantBPipelineAgBgCrV2>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - Base::M_Warp, - Base::N_Warp, - Base::M_Warp_Tile, - Base::N_Warp_Tile, - Base::K_Warp_Tile, - false, // transpose_c - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledMMAPermuteN>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue, + ADataType, + BDataType>, + ck_tile::tuple<>, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + false, // transpose_c + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledMMAPermuteN>>; using Kernel = ck_tile::QuantGemmKernel Date: Thu, 11 Dec 2025 08:09:29 -0800 Subject: [PATCH 09/10] Fix compilation errors with latest clang22 version. 
(#3396) * remove target attributes from deduction guides * switch CK_TILE_HOST_DEVICE_EXTERN based on clang version --- include/ck_tile/core/config.hpp | 4 ++++ include/ck_tile/core/numeric/math.hpp | 21 +++++++------------ .../core/utility/unary_element_function.hpp | 3 +-- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 678a2fbfff..0e7d1def75 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -39,8 +39,12 @@ #define CK_TILE_DEVICE inline __device__ #define CK_TILE_HOST_DEVICE inline __host__ __device__ #define CK_TILE_DEVICE_EXTERN __device__ +#if __clang_major__ < 22 #define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__ #else +#define CK_TILE_HOST_DEVICE_EXTERN +#endif +#else #define CK_TILE_HOST inline #define CK_TILE_DEVICE inline #define CK_TILE_HOST_DEVICE inline diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 57f3953514..8a0e3b3408 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -41,9 +41,8 @@ struct scales Scale lhs_; }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ scales(Scale) -> scales; +CK_TILE_HOST_DEVICE_EXTERN scales(Scale) -> scales; template struct plus @@ -66,8 +65,7 @@ struct plus } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ plus() -> plus; +CK_TILE_HOST_DEVICE_EXTERN plus() -> plus; template struct minus @@ -90,8 +88,7 @@ struct minus } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ minus() -> minus; +CK_TILE_HOST_DEVICE_EXTERN minus() -> minus; template struct multiplies @@ -114,8 +111,7 @@ struct multiplies } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ multiplies() -> multiplies; 
+CK_TILE_HOST_DEVICE_EXTERN multiplies() -> multiplies; template struct maximize @@ -345,8 +341,7 @@ struct equal } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ equal() -> equal; +CK_TILE_HOST_DEVICE_EXTERN equal() -> equal; template <> struct equal @@ -387,8 +382,7 @@ struct less } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less() -> less; +CK_TILE_HOST_DEVICE_EXTERN less() -> less; template struct less_equal @@ -411,8 +405,7 @@ struct less_equal } }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less_equal() -> less_equal; +CK_TILE_HOST_DEVICE_EXTERN less_equal() -> less_equal; template <> struct less_equal diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index b195275bdc..595b8522da 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -47,9 +47,8 @@ struct composes F f_; }; -/// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ composes(Ts&&...) -> composes...>; +CK_TILE_HOST_DEVICE_EXTERN composes(Ts&&...) 
-> composes...>; template struct saturates From 4dcc3e59c1c0195dae7ee9da9ab76d18a4cafe9f Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 11 Dec 2025 20:25:29 +0400 Subject: [PATCH 10/10] chore: update copyright header for misc files (#3402) * chore: update copyright header for misc files * fix: typo in kernel resulting in ci failure --- docs/conceptual/ck_tile/convert_mermaid_to_svg.py | 3 +++ docs/conceptual/ck_tile/convert_raw_html_to_commented.py | 3 +++ docs/conceptual/ck_tile/update_diagrams.py | 3 +++ example/test_old_ck_gpu_reference.cpp | 2 +- experimental/builder/test/test_ckb_conv_builder.cpp | 2 ++ include/ck_tile/ref/conv_common.hpp | 2 +- include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp | 2 +- include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp | 2 +- include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp | 2 +- .../device_grouped_gemm_wmma_splitk_instance.hpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp | 2 +- .../test_gemm_quant_bquant_preshuffle.cpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp | 2 +- test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp | 2 +- test/ck_tile/utility/test_fill.cpp | 2 +- test/ck_tile/warp_gemm/CMakeLists.txt | 3 +++ test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp | 2 +- .../practice_gemm_host_pipeline_agmem_bgmem_creg.hpp | 8 ++++---- 19 files changed, 31 insertions(+), 17 deletions(-) diff --git a/docs/conceptual/ck_tile/convert_mermaid_to_svg.py b/docs/conceptual/ck_tile/convert_mermaid_to_svg.py index 1d62405e53..2bfaffdb57 100644 --- a/docs/conceptual/ck_tile/convert_mermaid_to_svg.py +++ b/docs/conceptual/ck_tile/convert_mermaid_to_svg.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """ Script to convert all mermaid diagrams in CK Tile docs to SVGs. 
This script: diff --git a/docs/conceptual/ck_tile/convert_raw_html_to_commented.py b/docs/conceptual/ck_tile/convert_raw_html_to_commented.py index e90bf9def0..8e4a849e7f 100644 --- a/docs/conceptual/ck_tile/convert_raw_html_to_commented.py +++ b/docs/conceptual/ck_tile/convert_raw_html_to_commented.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """Convert raw HTML mermaid blocks to commented format for SVG conversion.""" import os diff --git a/docs/conceptual/ck_tile/update_diagrams.py b/docs/conceptual/ck_tile/update_diagrams.py index 2fbe2ef5a9..f78599010e 100644 --- a/docs/conceptual/ck_tile/update_diagrams.py +++ b/docs/conceptual/ck_tile/update_diagrams.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """ Helper script to update SVG diagrams from commented mermaid sources in RST files. diff --git a/example/test_old_ck_gpu_reference.cpp b/example/test_old_ck_gpu_reference.cpp index 0bcf43d20b..9f12eaea4d 100644 --- a/example/test_old_ck_gpu_reference.cpp +++ b/example/test_old_ck_gpu_reference.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. // Standalone test program for Old CK GPU references // Tests naive_conv_fwd (existing) and future backward ops diff --git a/experimental/builder/test/test_ckb_conv_builder.cpp b/experimental/builder/test/test_ckb_conv_builder.cpp index e69de29bb2..81e63887c1 100644 --- a/experimental/builder/test/test_ckb_conv_builder.cpp +++ b/experimental/builder/test/test_ckb_conv_builder.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT diff --git a/include/ck_tile/ref/conv_common.hpp b/include/ck_tile/ref/conv_common.hpp index ed43e87b14..50ae18eb99 100644 --- a/include/ck_tile/ref/conv_common.hpp +++ b/include/ck_tile/ref/conv_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp b/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp index a5f6a697f2..f75bdda912 100644 --- a/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp +++ b/include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp b/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp index 2ac9c19892..0839074dd4 100644 --- a/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp +++ b/include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp b/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp index 720fa40297..f582fcd71a 100644 --- a/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp +++ b/include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp index 6d5da9208b..d0de1c859b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp index 9ba0b9c804..b6e69cd649 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp index ec123364cb..4b1ad068a7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 3a62fc091a..ae01bddf96 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp index 5a58ed886a..bb0fa21899 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp index 0fa4048dab..8b4c90f8b9 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" diff --git a/test/ck_tile/utility/test_fill.cpp b/test/ck_tile/utility/test_fill.cpp index 18f42c4ad0..3633f8bbff 100644 --- a/test/ck_tile/utility/test_fill.cpp +++ b/test/ck_tile/utility/test_fill.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/host/fill.hpp" #include "ck_tile/host/joinable_thread.hpp" diff --git a/test/ck_tile/warp_gemm/CMakeLists.txt b/test/ck_tile/warp_gemm/CMakeLists.txt index 664ebc003b..5079741e1b 100644 --- a/test/ck_tile/warp_gemm/CMakeLists.txt +++ b/test/ck_tile/warp_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp) endif() diff --git a/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp index 7878fda618..47fa1ff43e 100644 --- a/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp +++ b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include "ck_tile/host.hpp" diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp index 15c1743a86..45f439e8fa 100644 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -28,9 +28,9 @@ struct PracticeGemmHostPipeline { // Size of the entire problem - const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K - const auto N = c_dram_ref.get_tensor_descriptor().get_length(number<1>{}); // M x N - const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K + const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N + const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K // Size of the block tile const auto MPerBlock = BlockTile::at(number<0>{}); @@ -83,7 +83,7 @@ struct PracticeGemmHostPipeline __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; const auto c_block_tile = block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); - auto c_window = make_tile_window(c_dram_ref, + auto c_window = make_tile_window(c_dram, make_tuple(number{}, number{}), {tile_origin_m, tile_origin_n}); store_tile(c_window, c_block_tile);