From a06b88b3fa704363c6d6f03c9508a5ceba145032 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Fri, 27 Feb 2026 14:48:01 +0100 Subject: [PATCH] [CK_BUILDER] ck builder conv transfer fix (#4750) ## Motivation This PR fixes how CK Builder validates the transfer vector size and adds proper validation for the LDS transfer vector size as well. ## Changes: * [__source vector dim__] -- Before this PR, the data transfer validation logic didn't allow setting the source vectorized dimension to 1. However, there are CK instances that do this when group merging is used. This is used only for the `DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle` kernel. * [__valid vector size__] -- Before this PR, the validation logic considered only the maximum vector size of a single instruction. However, our buffer-loading logic implements support for loading more values through multiple buffer instructions. This was also found to be used in some convolution instances, so this behavior is now reflected in the validation logic. * [__valid LDS vector size__] -- Before this PR, the LDS vector size validation was done in the same way as for VMEM. This PR adds proper LDS vector size validation based on the available LDS instruction sizes. ## Test Plan Run the CK BUILDER conv fwd factories tests. ## Test Result All CK BUILDER conv fwd factories work (except the DL and CK Tile ones, since they are not yet added). ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- .pre-commit-config.yaml | 2 +- .../builder/conv_algorithm_concepts.hpp | 2 +- .../ck_tile/builder/conv_algorithm_limits.hpp | 159 +++++++++++++----- .../factory/conv_fwd_large_tensor_factory.hpp | 8 +- .../builder/factory/conv_fwd_v3_factory.hpp | 8 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 8 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 17 +- .../helpers/ck/conv_block_transfer.hpp | 2 +- .../factory/helpers/ck/conv_tensor_type.hpp | 62 ++++--- .../conv/ck/test_ckb_conv_fwd_1d_fp16.cpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 2 +- .../conv/ck/test_ckb_conv_fwd_3d_fp32.cpp | 77 +++++++++ .../test/impl/conv_algorithm_types.hpp | 2 +- .../builder/test/test_conv_description.cpp | 2 +- .../test/utils/ckb_conv_test_configs.hpp | 46 ++++- .../test/utils/conv_algorithm_type_utils.hpp | 4 +- script/remod_for_ck_tile.py | 4 +- 17 files changed, 304 insertions(+), 103 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c2a4ade95..daf3c258d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: name: Run ck_tile remod.py entry: python projects/composablekernel/script/remod_for_ck_tile.py language: python - files: '^(include|example)/ck_tile/.*$' + files: '^projects/composablekernel/(include|example)/ck_tile/.*$' additional_dependencies: - dos2unix - clang-format==18.1.3 diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 9cff75f049..b045fb04fe 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -104,7 +104,7 @@ concept LdsTransferDescriptor = requires(T t) { template concept EpilogueDescriptor = requires(T t) { { t.m_xdl_per_wave_per_shuffle } -> SizeType; - { t.n_per_wave_per_shuffle } -> SizeType; + { t.n_xdl_per_wave_per_shuffle } -> SizeType; { 
t.scalar_per_vector } -> SizeType; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 5196eae6c7..973fb2b012 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -6,24 +6,36 @@ #include #include #include -#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +/** + * @file conv_algorithm_limits.hpp + * @brief Compile-time validation concepts and helpers for convolution algorithm configurations + * + * This file provides C++20 concepts and compile-time validation functions for validating + * block transfer configurations, memory access patterns, and hardware instruction constraints + * in convolution algorithms. + * + * Key features: + * - Vector transfer size validation for VMEM and LDS operations + * - Access order permutation validation + * - Thread cluster dimension validation + * - Tile coverage validation for block transfers + */ namespace ck_tile::builder { -// Limits for input vector transfer. template concept InputVectorTransferLimits = requires { requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 && Value.lds_dst_scalar_per_vector > 0; }; -// Limits for input and output vector transfer (CK Tile). template concept TileInputOutputVectorTransferLimits = requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; }; -// Limits for output vector transfer. 
template concept OutputVectorTransferLimits = requires { requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 && @@ -174,13 +186,70 @@ constexpr auto get_mn_coverage() return mn; } -template -constexpr auto get_data_max_vec_size() +template +constexpr bool IsVmemVectorSizeValid() { - constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width(); - static_assert(max_vec_inst_size_bytes % DataTypeSize == 0, - "The max vec instruction size is not a multiple of given data type size."); - return max_vec_inst_size_bytes / DataTypeSize; + using enum builder::DataType; + // We have following type & VectorSize pair constraints. + //----------------------------------------------------------------------------------- + // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8)) || + // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + // (std::is_same_v && + // (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) || + // (std::is_same_v && + // (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) || + // (std::is_same_v && + // (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + // (std::is_same_v && + // (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + // (std::is_same_v && + // (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) + //----------------------------------------------------------------------------------- + // explicitly not using switch statement since we do not handle all possible data types + // in DataType structure yet, so that I could cover all of them in 
`else` branch. + if constexpr(Type == FP64) + { + return N == 1 || N == 2 || N == 4 || N == 8; + } + else if constexpr(Type == FP32) + { + return N == 1 || N == 2 || N == 4 || N == 8 || N == 16; + } + else if constexpr(Type == I32) + { + return N == 1 || N == 2 || N == 4 || N == 8 || N == 16; + } + else if constexpr(Type == FP16 || Type == BF16) + { + return N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32; + } + else if constexpr(Type == FP8 || Type == BF8) + { + return N == 1 || N == 2 || N == 4 || N == 8 || N == 16; + } + else if constexpr(Type == I8) + { + return N == 1 || N == 2 || N == 4 || N == 8 || N == 16; + } + else + { + static_assert(always_false, "Unsupported memory instruction data type!"); + } +} + +// Valid LDS instruction bit sizes based on supported DS_READ/DS_WRITE operations +// DS_READ_{B32,B64,B96,B128,U8,I8,U16,I16} +// DS_WRITE_{B32,B64,B96,B128,B8,B16} +template +constexpr bool IsLDSVectorSizeValid() +{ + constexpr size_t bits = N * DataTypeSize * 8; + return ck_tile::is_any_value_of(bits, 8, 16, 32, 64, 96, 128); } } // namespace detail @@ -217,52 +286,52 @@ concept ThreadsCoverCTile = requires { CBlockTransfer.scalar_per_vector) == 0; }; -template -concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0); +template +concept IsVmemVectorSizeValid = detail::IsVmemVectorSizeValid(); -template -concept IsVectorSizeValid = - IsPowerOf2 && (ScalarPerVec <= detail::get_data_max_vec_size()); +template +concept IsLDSVectorSizeValid = detail::IsLDSVectorSizeValid(); // Composite concept for input block transfer validation (A) // Includes all validations: vector transfer limits, access order, cluster size, // vector size validity, and tile coverage -template +template concept ValidABlockTransfer = - InputVectorTransferLimits && - AccessOrderLimits && - AccessOrderLimits && - ValidBlockTransferClusterSize && - IsVectorSizeValid && - IsVectorSizeValid && - ThreadsCoverATile; + InputVectorTransferLimits && + 
AccessOrderLimits && + AccessOrderLimits && + ValidBlockTransferClusterSize && + IsVmemVectorSizeValid && + IsLDSVectorSizeValid && + ThreadsCoverATile; // Composite concept for input block transfer validation (B) -template +template concept ValidBBlockTransfer = - InputVectorTransferLimits && - AccessOrderLimits && - AccessOrderLimits && - ValidBlockTransferClusterSize && - IsVectorSizeValid && - IsVectorSizeValid && - ThreadsCoverBTile; + InputVectorTransferLimits && + AccessOrderLimits && + AccessOrderLimits && + ValidBlockTransferClusterSize && + IsVmemVectorSizeValid && + IsLDSVectorSizeValid && + ThreadsCoverBTile; // Composite concept for output block transfer validation (C) -template -concept ValidCBlockTransfer = - OutputVectorTransferLimits && - ValidBlockTransferClusterSize && - IsVectorSizeValid && - ThreadsCoverCTile; +template +concept ValidCBlockTransfer = OutputVectorTransferLimits && + ValidBlockTransferClusterSize && + IsVmemVectorSizeValid && + ThreadsCoverCTile; // Usage: IsValidLayout template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index f7c98f244d..038f9847a6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -48,15 +48,17 @@ struct ConvFwdLargeTensorFactory // Check limits for the data transfer parameters. 
static_assert(ValidABlockTransfer); static_assert(ValidBBlockTransfer); static_assert(ValidCBlockTransfer); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 7ea9938ea4..a417242e54 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -53,15 +53,17 @@ struct ConvFwdXdlV3Factory // Check limits for the algorithm parameters. static_assert(ValidABlockTransfer); static_assert(ValidBBlockTransfer); static_assert(ValidCBlockTransfer); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 652b032a9b..67cc5ce450 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -49,15 +49,17 @@ struct ConvFwdWmmaFactory // Check limits for the algorithm parameters. static_assert(ValidABlockTransfer); static_assert(ValidBBlockTransfer); static_assert(ValidCBlockTransfer); // TODO: verify Ds transfer as well diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 79bcd84981..bb1f5e8dda 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -48,15 +48,20 @@ struct ConvFwdXdlFactory // Check limits for the algorithm parameters. 
static_assert(ValidABlockTransfer); + static_assert(A_BLOCK_TRANSFER.src_vector_dim == 2 || + (ALGORITHM.num_conv_groups_to_merge > 1 && A_BLOCK_TRANSFER.src_vector_dim == 1)); static_assert(ValidBBlockTransfer); + static_assert(B_BLOCK_TRANSFER.src_vector_dim == 2); static_assert(ValidCBlockTransfer); @@ -74,8 +79,7 @@ struct ConvFwdXdlFactory NDHWGC, NGCW, NGCHW, - NGCDHW> && - A_BLOCK_TRANSFER.src_vector_dim == 2); + NGCDHW>); static_assert(IsValidLayout && - B_BLOCK_TRANSFER.src_vector_dim == 2); + GKCZYX>); static_assert(IsValidLayout -consteval auto GetTensorDataAndComputeTypes() +consteval auto ExtractTensorDataType() { - constexpr auto data_type = Config.data_type; - constexpr auto compute_type = Config.compute_type; + constexpr auto data_type = Config.data_type; using enum DataType; - - if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE) + if constexpr(data_type == UNDEFINED_DATA_TYPE) { - return std::make_pair(ConvertDataTypeToCK(), - ConvertDataTypeToCK()); - } - else if constexpr(data_type == UNDEFINED_DATA_TYPE) - { - return std::make_pair(ConvertDataTypeToCK(), - ConvertDataTypeToCK()); - } - else if constexpr(compute_type == UNDEFINED_DATA_TYPE) - { - return std::make_pair(ConvertDataTypeToCK(), - ConvertDataTypeToCK()); + return SignatureDataType; } else { - return std::make_pair(ConvertDataTypeToCK(), - ConvertDataTypeToCK()); + return data_type; } } +template +consteval auto ExtractTensorComputeType() +{ + constexpr auto compute_type = Config.compute_type; + + using enum DataType; + if constexpr(compute_type == UNDEFINED_DATA_TYPE) + { + return SignatureDataType; + } + else + { + return compute_type; + } +} + +template +consteval auto GetTensorDataAndComputeTypes() +{ + constexpr auto data_type = ExtractTensorDataType(); + constexpr auto compute_type = ExtractTensorComputeType(); + + return std::make_pair(data_type, compute_type); +} + template consteval auto GetTensorAccumulationType() { @@ -158,6 +169,7 @@ 
consteval auto GetAuxiliaryTensorDataTypes() template struct ConvTensorDataTypes { + // Builder enumerator types static constexpr auto input_types = GetTensorDataAndComputeTypes(); static constexpr auto weight_types = @@ -165,12 +177,12 @@ struct ConvTensorDataTypes static constexpr auto output_types = GetTensorDataAndComputeTypes(); - using InDataType = typename decltype(input_types.first)::type; - using InComputeType = typename decltype(input_types.second)::type; - using WeiDataType = typename decltype(weight_types.first)::type; - using WeiComputeType = typename decltype(weight_types.second)::type; - using OutDataType = typename decltype(output_types.first)::type; - using OutComputeType = typename decltype(output_types.second)::type; + using InDataType = typename DataTypeToCK::type; + using InComputeType = typename DataTypeToCK::type; + using WeiDataType = typename DataTypeToCK::type; + using WeiComputeType = typename DataTypeToCK::type; + using OutDataType = typename DataTypeToCK::type; + using OutComputeType = typename DataTypeToCK::type; using AccDataType = typename decltype(GetTensorAccumulationType())::type; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index d3ace110c4..a7af9f313f 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -29,7 +29,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(Transfer_4x16x1) + .with_transfer(Transfer_4x16x1_asrc_vec_dim1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(2); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 06d200429c..1c180f4859 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -31,7 +31,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} .with_thread_block(ThreadBlock_128_64x64x64) .with_gemm_config(GemmParams_Wmma_2x1_per_wave) - .with_transfer(Transfer_4x32x1) + .with_transfer(Transfer_4x16x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, PipelineScheduler::DEFAULT) .with_num_conv_groups_to_merge(2) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index f5779bf5ae..c41a88fa1a 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -48,4 +48,81 @@ TEST(FwdConvInstances, "MNKPadding"}); } +// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0 +TEST( + FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst_LargeVecSize) +{ + using enum ck_tile::builder::ConvDirection; + using enum ck_tile::builder::DataType; + using enum ck_tile::builder::TensorLayout; + + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = FORWARD, + .data_type = FP32, + .accumulation_data_type = FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKCZYX}}, + .output = {.config = {.layout = NGKDHW}}}; + + constexpr Transfer<> Transfer_4x64x1_Vec16{ + .a = + { + .block_transfer = {.k0 = 2, .m_n = 128, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = 
{1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 4, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, + .scalar_per_vector = 4}, + }, + }; + + constexpr GridwiseFwdXdlGemm FwdGemmParams{ + .ak1 = 16, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(ThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams) + .with_transfer(Transfer_4x64x1_Vec16) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v1_intrawave); + + using Builder = ConvBuilder; + + const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + expected_transfer_parameters, + "Filter1x1Pad0", + "Intrawave", + "v1", + "NGCDHW,GKCZYX,EmptyTuple,NGKDHW", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); +} + } // namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index bcf17fd087..59d29b1280 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -133,7 +133,7 @@ static_assert(LdsTransferDescriptor); struct Epilogue { size_t m_xdl_per_wave_per_shuffle; - size_t n_per_wave_per_shuffle; + size_t n_xdl_per_wave_per_shuffle; size_t scalar_per_vector; }; static_assert(EpilogueDescriptor); diff --git 
a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 40ea364ba9..aa2700c80e 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -154,7 +154,7 @@ struct DefaultAlgorithm .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, .scalar_per_vector = 2}, }, }; diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 8b7d68f8db..641787f7df 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -78,7 +78,7 @@ constexpr Transfer<> Transfer_4x64x1{ .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, .scalar_per_vector = 4}, }, }; @@ -111,7 +111,7 @@ constexpr Transfer<4> BwdTransfer_4x64x1{ .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, }, }; @@ -144,7 +144,7 @@ constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, .scalar_per_vector = 2}, }, }; @@ -177,7 +177,7 @@ constexpr Transfer<> Transfer_4x64x1_fp8{ .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, + 
.n_xdl_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, }, }; @@ -210,12 +210,46 @@ constexpr Transfer<> Transfer_4x16x1{ .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, }, }; +constexpr Transfer<> Transfer_4x16x1_asrc_vec_dim1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 4, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {0, 2, 1}, + .src_access_order = {0, 2, 1}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 1, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, + .scalar_per_vector = 1}, + + }, +}; + constexpr Transfer<> Transfer_4x32x1{ .a = { @@ -244,7 +278,7 @@ constexpr Transfer<> Transfer_4x32x1{ .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4}, .epilogue = {.m_xdl_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, + .n_xdl_per_wave_per_shuffle = 1, .scalar_per_vector = 8}, }, }; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index cc7dde885a..ccf1b8da2f 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -194,8 +194,8 @@ template <> inline std::string to_string(OutputTransfer t) { 
std::ostringstream oss; - oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," - << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector; + oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_xdl_per_wave_per_shuffle + << "," << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector; return oss.str(); } diff --git a/script/remod_for_ck_tile.py b/script/remod_for_ck_tile.py index feb50dc290..84652680ee 100755 --- a/script/remod_for_ck_tile.py +++ b/script/remod_for_ck_tile.py @@ -4,8 +4,8 @@ import os root_dir = os.getcwd() -ck_tile_include = root_dir + "/include/ck_tile" -ck_tile_example = root_dir + "/example/ck_tile" +ck_tile_include = root_dir + "/projects/composablekernel/include/ck_tile" +ck_tile_example = root_dir + "/projects/composablekernel/example/ck_tile" # Run for include os.chdir(ck_tile_include)