diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9c2a4ade95..daf3c258d9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -36,7 +36,7 @@ repos:
         name: Run ck_tile remod.py
         entry: python projects/composablekernel/script/remod_for_ck_tile.py
         language: python
-        files: '^(include|example)/ck_tile/.*$'
+        files: '^projects/composablekernel/(include|example)/ck_tile/.*$'
         additional_dependencies:
           - dos2unix
           - clang-format==18.1.3
diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp
index 9cff75f049..b045fb04fe 100644
--- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp
+++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp
@@ -104,7 +104,7 @@ concept LdsTransferDescriptor = requires(T t) {
 template <typename T>
 concept EpilogueDescriptor = requires(T t) {
     { t.m_xdl_per_wave_per_shuffle } -> SizeType;
-    { t.n_per_wave_per_shuffle } -> SizeType;
+    { t.n_xdl_per_wave_per_shuffle } -> SizeType;
     { t.scalar_per_vector } -> SizeType;
 };
diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp
index 5196eae6c7..973fb2b012 100644
--- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp
+++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp
@@ -6,24 +6,36 @@
 #include
 #include
 #include
-#include "ck_tile/core/utility/type_traits.hpp"
 #include "ck_tile/core/arch/arch.hpp"
+#include "ck_tile/core/utility/type_traits.hpp"
+
+/**
+ * @file conv_algorithm_limits.hpp
+ * @brief Compile-time validation concepts and helpers for convolution algorithm configurations
+ *
+ * This file provides C++20 concepts and compile-time helper functions that validate
+ * block transfer configurations, memory access patterns, and hardware instruction constraints
+ * in convolution algorithms.
+ *
+ * Key features:
+ * - Vector transfer size validation for VMEM and LDS operations
+ * - Access order permutation validation
+ * - Thread cluster dimension validation
+ * - Tile coverage validation for block transfers
+ */

 namespace ck_tile::builder {

-// Limits for input vector transfer.
 template
 concept InputVectorTransferLimits = requires {
     requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
                  Value.lds_dst_scalar_per_vector > 0;
 };

-// Limits for input and output vector transfer (CK Tile).
 template
 concept TileInputOutputVectorTransferLimits = requires {
     requires Value.a > 0 && Value.b > 0 && Value.c > 0;
 };

-// Limits for output vector transfer.
 template
 concept OutputVectorTransferLimits = requires {
     requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
@@ -174,13 +186,70 @@ constexpr auto get_mn_coverage()
     return mn;
 }

-template
-constexpr auto get_data_max_vec_size()
+template <DataType Type, size_t N>
+constexpr bool IsVmemVectorSizeValid()
 {
-    constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width();
-    static_assert(max_vec_inst_size_bytes % DataTypeSize == 0,
-                  "The max vec instruction size is not a multiple of given data type size.");
-    return max_vec_inst_size_bytes / DataTypeSize;
+    using enum builder::DataType;
+    // We have the following type & VectorSize pair constraints.
+    //-----------------------------------------------------------------------------------
+    // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+    // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+    // (std::is_same_v &&
+    //  (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
+    // (std::is_same_v &&
+    //  (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
+    // (std::is_same_v &&
+    //  (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+    // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+    // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+    // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+    // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+    // (std::is_same_v &&
+    //  (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
+    // (std::is_same_v &&
+    //  (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+    // (std::is_same_v && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))
+    //-----------------------------------------------------------------------------------
+    // Explicitly not using a switch statement: the DataType enum does not yet include
+    // all possible data types, so the `else` branch catches any unhandled ones.
+    if constexpr(Type == FP64)
+    {
+        return N == 1 || N == 2 || N == 4 || N == 8;
+    }
+    else if constexpr(Type == FP32)
+    {
+        return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
+    }
+    else if constexpr(Type == I32)
+    {
+        return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
+    }
+    else if constexpr(Type == FP16 || Type == BF16)
+    {
+        return N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32;
+    }
+    else if constexpr(Type == FP8 || Type == BF8)
+    {
+        return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
+    }
+    else if constexpr(Type == I8)
+    {
+        return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
+    }
+    else
+    {
+        static_assert(always_false, "Unsupported memory instruction data type!");
+    }
+}
+
+// Valid LDS instruction bit sizes based on supported DS_READ/DS_WRITE operations
+// DS_READ_{B32,B64,B96,B128,U8,I8,U16,I16}
+// DS_WRITE_{B32,B64,B96,B128,B8,B16}
+template <size_t DataTypeSize, size_t N>
+constexpr bool IsLDSVectorSizeValid()
+{
+    constexpr size_t bits = N * DataTypeSize * 8;
+    return ck_tile::is_any_value_of(bits, 8, 16, 32, 64, 96, 128);
+}

 } // namespace detail
@@ -217,52 +286,52 @@ concept ThreadsCoverCTile = requires {
                             CBlockTransfer.scalar_per_vector) == 0;
 };

-template
-concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0);
+template <DataType Type, size_t N>
+concept IsVmemVectorSizeValid = detail::IsVmemVectorSizeValid<Type, N>();

-template
-concept IsVectorSizeValid =
-    IsPowerOf2 && (ScalarPerVec <= detail::get_data_max_vec_size());
+template <size_t DataTypeSize, size_t N>
+concept IsLDSVectorSizeValid = detail::IsLDSVectorSizeValid<DataTypeSize, N>();

 // Composite concept for input block transfer validation (A)
 // Includes all validations: vector transfer limits, access order, cluster size,
 // vector size validity, and tile coverage
-template
+template
 concept ValidABlockTransfer =
-    InputVectorTransferLimits &&
-    AccessOrderLimits &&
-    AccessOrderLimits &&
-    ValidBlockTransferClusterSize &&
-    IsVectorSizeValid &&
-    IsVectorSizeValid &&
-    ThreadsCoverATile;
+    InputVectorTransferLimits &&
+    AccessOrderLimits &&
+    AccessOrderLimits &&
+    ValidBlockTransferClusterSize &&
+    IsVmemVectorSizeValid &&
+    IsLDSVectorSizeValid &&
+    ThreadsCoverATile;

 // Composite concept for input block transfer validation (B)
-template
+template
 concept ValidBBlockTransfer =
-    InputVectorTransferLimits &&
-    AccessOrderLimits &&
-    AccessOrderLimits &&
-    ValidBlockTransferClusterSize &&
-    IsVectorSizeValid &&
-    IsVectorSizeValid &&
-    ThreadsCoverBTile;
+    InputVectorTransferLimits &&
+    AccessOrderLimits &&
+    AccessOrderLimits &&
+    ValidBlockTransferClusterSize &&
+    IsVmemVectorSizeValid &&
+    IsLDSVectorSizeValid &&
+    ThreadsCoverBTile;

 // Composite concept for output block transfer validation (C)
-template
-concept ValidCBlockTransfer =
-    OutputVectorTransferLimits &&
-    ValidBlockTransferClusterSize &&
-    IsVectorSizeValid &&
-    ThreadsCoverCTile;
+template
+concept ValidCBlockTransfer = OutputVectorTransferLimits &&
+                              ValidBlockTransferClusterSize &&
+                              IsVmemVectorSizeValid &&
+                              ThreadsCoverCTile;

 // Usage: IsValidLayout
 template
diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp
index f7c98f244d..038f9847a6 100644
--- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp
+++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp
@@ -48,15 +48,17 @@ struct ConvFwdLargeTensorFactory
     // Check limits for the data transfer parameters.
     static_assert(ValidABlockTransfer);
     static_assert(ValidBBlockTransfer);
     static_assert(ValidCBlockTransfer);
diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp
index 7ea9938ea4..a417242e54 100644
--- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp
+++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp
@@ -53,15 +53,17 @@ struct ConvFwdXdlV3Factory
     // Check limits for the algorithm parameters.
     static_assert(ValidABlockTransfer);
     static_assert(ValidBBlockTransfer);
     static_assert(ValidCBlockTransfer);
diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp
index 652b032a9b..67cc5ce450 100644
--- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp
+++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp
@@ -49,15 +49,17 @@ struct ConvFwdWmmaFactory
     // Check limits for the algorithm parameters.
     static_assert(ValidABlockTransfer);
     static_assert(ValidBBlockTransfer);
     static_assert(ValidCBlockTransfer);
     // TODO: verify Ds transfer as well
diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp
index 79bcd84981..bb1f5e8dda 100644
--- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp
+++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp
@@ -48,15 +48,20 @@ struct ConvFwdXdlFactory
     // Check limits for the algorithm parameters.
     static_assert(ValidABlockTransfer);
+    static_assert(A_BLOCK_TRANSFER.src_vector_dim == 2 ||
+                  (ALGORITHM.num_conv_groups_to_merge > 1 && A_BLOCK_TRANSFER.src_vector_dim == 1));
     static_assert(ValidBBlockTransfer);
+    static_assert(B_BLOCK_TRANSFER.src_vector_dim == 2);
     static_assert(ValidCBlockTransfer);
@@ -74,8 +79,7 @@ struct ConvFwdXdlFactory
                                       NDHWGC,
                                       NGCW,
                                       NGCHW,
-                                      NGCDHW> &&
-                  A_BLOCK_TRANSFER.src_vector_dim == 2);
+                                      NGCDHW>);
     static_assert(IsValidLayout &&
-                  B_BLOCK_TRANSFER.src_vector_dim == 2);
+                  GKCZYX>);
     static_assert(IsValidLayout
-consteval auto GetTensorDataAndComputeTypes()
+consteval auto ExtractTensorDataType()
 {
-    constexpr auto data_type    = Config.data_type;
-    constexpr auto compute_type = Config.compute_type;
+    constexpr auto data_type = Config.data_type;

     using enum DataType;
-
-    if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE)
+    if constexpr(data_type == UNDEFINED_DATA_TYPE)
     {
-        return std::make_pair(ConvertDataTypeToCK(),
-                              ConvertDataTypeToCK());
-    }
-    else if constexpr(data_type == UNDEFINED_DATA_TYPE)
-    {
-        return std::make_pair(ConvertDataTypeToCK(),
-                              ConvertDataTypeToCK());
-    }
-    else if constexpr(compute_type == UNDEFINED_DATA_TYPE)
-    {
-        return std::make_pair(ConvertDataTypeToCK(),
-                              ConvertDataTypeToCK());
+        return SignatureDataType;
     }
     else
     {
-        return std::make_pair(ConvertDataTypeToCK(),
-                              ConvertDataTypeToCK());
+        return data_type;
     }
 }

+template
+consteval auto ExtractTensorComputeType()
+{
+    constexpr auto compute_type = Config.compute_type;
+
+    using enum DataType;
+    if constexpr(compute_type == UNDEFINED_DATA_TYPE)
+    {
+        return SignatureDataType;
+    }
+    else
+    {
+        return compute_type;
+    }
+}
+
+template
+consteval auto GetTensorDataAndComputeTypes()
+{
+    constexpr auto data_type    = ExtractTensorDataType();
+    constexpr auto compute_type = ExtractTensorComputeType();
+
+    return std::make_pair(data_type, compute_type);
+}
+
 template
 consteval auto GetTensorAccumulationType()
 {
@@ -158,6 +169,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
 template
 struct ConvTensorDataTypes
 {
+    // Builder enumerator types
     static constexpr auto input_types =
         GetTensorDataAndComputeTypes();
     static constexpr auto weight_types =
@@ -165,12 +177,12 @@ struct ConvTensorDataTypes
     static constexpr auto output_types =
         GetTensorDataAndComputeTypes();

-    using InDataType     = typename decltype(input_types.first)::type;
-    using InComputeType  = typename decltype(input_types.second)::type;
-    using WeiDataType    = typename decltype(weight_types.first)::type;
-    using WeiComputeType = typename decltype(weight_types.second)::type;
-    using OutDataType    = typename decltype(output_types.first)::type;
-    using OutComputeType = typename decltype(output_types.second)::type;
+    using InDataType     = typename DataTypeToCK::type;
+    using InComputeType  = typename DataTypeToCK::type;
+    using WeiDataType    = typename DataTypeToCK::type;
+    using WeiComputeType = typename DataTypeToCK::type;
+    using OutDataType    = typename DataTypeToCK::type;
+    using OutComputeType = typename DataTypeToCK::type;

     using AccDataType = typename decltype(GetTensorAccumulationType())::type;
diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
index d3ace110c4..a7af9f313f 100644
--- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
+++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
@@ -29,7 +29,7 @@ TEST(FwdConvInstances,
         ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
             .with_thread_block(ThreadBlock_64_64x32x32)
             .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
-            .with_transfer(Transfer_4x16x1)
+            .with_transfer(Transfer_4x16x1_asrc_vec_dim1)
             .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
             .with_prefetch_config(1, PipelineScheduler::DEFAULT)
             .with_num_conv_groups_to_merge(2);
diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp
index 06d200429c..1c180f4859 100644
--- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp
+++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp
@@ -31,7 +31,7 @@ TEST(FwdConvInstances,
         ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
             .with_thread_block(ThreadBlock_128_64x64x64)
             .with_gemm_config(GemmParams_Wmma_2x1_per_wave)
-            .with_transfer(Transfer_4x32x1)
+            .with_transfer(Transfer_4x16x1)
             .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
             .with_prefetch_config(1, PipelineScheduler::DEFAULT)
             .with_num_conv_groups_to_merge(2)
diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
index f5779bf5ae..c41a88fa1a 100644
--- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
+++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
@@ -48,4 +48,81 @@ TEST(FwdConvInstances,
         "MNKPadding"});
 }

+// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
+TEST(
+    FwdConvInstances,
+    Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst_LargeVecSize)
+{
+    using enum ck_tile::builder::ConvDirection;
+    using enum ck_tile::builder::DataType;
+    using enum ck_tile::builder::TensorLayout;
+
+    constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
+                                             .direction = FORWARD,
+                                             .data_type = FP32,
+                                             .accumulation_data_type = FP32,
+                                             .input = {.config = {.layout = NGCDHW}},
+                                             .weight = {.config = {.layout = GKCZYX}},
+                                             .output = {.config = {.layout = NGKDHW}}};
+
+    constexpr Transfer<> Transfer_4x64x1_Vec16{
+        .a =
+            {
+                .block_transfer = {.k0 = 2, .m_n = 128, .k1 = 1},
+                .lds_transfer = {.src_vector_dim = 2,
+                                 .src_scalar_per_vector = 16,
+                                 .lds_dst_scalar_per_vector = 4,
+                                 .is_direct_load = false,
+                                 .lds_padding = false},
+                .thread_cluster_arrange_order = {1, 0, 2},
+                .src_access_order = {1, 0, 2},
+            },
+        .b =
+            {
+                .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
+                .lds_transfer = {.src_vector_dim = 2,
+                                 .src_scalar_per_vector = 4,
+                                 .lds_dst_scalar_per_vector = 4,
+                                 .is_direct_load = false,
+                                 .lds_padding = false},
+                .thread_cluster_arrange_order = {1, 0, 2},
+                .src_access_order = {1, 0, 2},
+            },
+        .c =
+            {
+                .thread_cluster_dims =
+                    {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
+                .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
+                             .n_xdl_per_wave_per_shuffle = 1,
+                             .scalar_per_vector = 4},
+            },
+    };
+
+    constexpr GridwiseFwdXdlGemm FwdGemmParams{
+        .ak1 = 16,
+        .bk1 = 8,
+        .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
+
+    constexpr auto FwdConvAlgorithm =
+        ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
+            .with_thread_block(ThreadBlock_256_256x256x32)
+            .with_gemm_config(FwdGemmParams)
+            .with_transfer(Transfer_4x64x1_Vec16)
+            .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
+                                      GemmSpecialization::MNKPadding)
+            .with_block_gemm(BlockGemmDesc_v1_intrawave);
+
+    using Builder = ConvBuilder;
+
+    const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
+    run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
+              expected_transfer_parameters,
+              "Filter1x1Pad0",
+              "Intrawave",
+              "v1",
+              "NGCDHW,GKCZYX,EmptyTuple,NGKDHW",
+              "PassThrough,PassThrough,PassThrough",
+              "MNKPadding"});
+}
+
 } // namespace
diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp
index bcf17fd087..59d29b1280 100644
--- a/experimental/builder/test/impl/conv_algorithm_types.hpp
+++ b/experimental/builder/test/impl/conv_algorithm_types.hpp
@@ -133,7 +133,7 @@ static_assert(LdsTransferDescriptor);
 struct Epilogue
 {
     size_t m_xdl_per_wave_per_shuffle;
-    size_t n_per_wave_per_shuffle;
+    size_t n_xdl_per_wave_per_shuffle;
     size_t scalar_per_vector;
 };
 static_assert(EpilogueDescriptor);
diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp
index 40ea364ba9..aa2700c80e 100644
--- a/experimental/builder/test/test_conv_description.cpp
+++ b/experimental/builder/test/test_conv_description.cpp
@@ -154,7 +154,7 @@ struct DefaultAlgorithm
             .thread_cluster_dims =
                 {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
             .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
-                         .n_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
                          .scalar_per_vector = 2},
         },
 };
diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp
index 8b7d68f8db..641787f7df 100644
--- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp
+++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp
@@ -78,7 +78,7 @@ constexpr Transfer<> Transfer_4x64x1{
             .thread_cluster_dims =
                 {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
             .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
-                         .n_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
                          .scalar_per_vector = 4},
         },
 };
@@ -111,7 +111,7 @@ constexpr Transfer<4> BwdTransfer_4x64x1{
             .thread_cluster_dims =
                 {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
             .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
-                         .n_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
                          .scalar_per_vector = 8},
         },
 };
@@ -144,7 +144,7 @@ constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
             .thread_cluster_dims =
                 {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
             .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
-                         .n_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
                          .scalar_per_vector = 2},
         },
 };
@@ -177,7 +177,7 @@ constexpr Transfer<> Transfer_4x64x1_fp8{
             .thread_cluster_dims =
                 {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
             .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
-                         .n_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
                          .scalar_per_vector = 8},
         },
 };
@@ -210,12 +210,46 @@ constexpr Transfer<> Transfer_4x16x1{
             .thread_cluster_dims =
                 {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4},
             .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
-                         .n_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
                          .scalar_per_vector = 8},
         },
 };

+constexpr Transfer<> Transfer_4x16x1_asrc_vec_dim1{
+    .a =
+        {
+            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 1,
+                             .src_scalar_per_vector = 4,
+                             .lds_dst_scalar_per_vector = 4,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {0, 2, 1},
+            .src_access_order = {0, 2, 1},
+        },
+    .b =
+        {
+            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 1,
+                             .lds_dst_scalar_per_vector = 8,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
+        },
+    .c =
+        {
+            .thread_cluster_dims =
+                {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4},
+            .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
+                         .scalar_per_vector = 1},
+
+        },
+};
+
 constexpr Transfer<> Transfer_4x32x1{
     .a =
         {
@@ -244,7 +278,7 @@ constexpr Transfer<> Transfer_4x32x1{
             .thread_cluster_dims =
                 {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4},
             .epilogue = {.m_xdl_per_wave_per_shuffle = 1,
-                         .n_per_wave_per_shuffle = 1,
+                         .n_xdl_per_wave_per_shuffle = 1,
                          .scalar_per_vector = 8},
         },
 };
diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp
index cc7dde885a..ccf1b8da2f 100644
--- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp
+++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp
@@ -194,8 +194,8 @@ template <>
 inline std::string to_string(OutputTransfer t)
 {
     std::ostringstream oss;
-    oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
-        << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector;
+    oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_xdl_per_wave_per_shuffle
+        << "," << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector;
     return oss.str();
 }
diff --git a/script/remod_for_ck_tile.py b/script/remod_for_ck_tile.py
index feb50dc290..84652680ee 100755
--- a/script/remod_for_ck_tile.py
+++ b/script/remod_for_ck_tile.py
@@ -4,8 +4,8 @@
 import os

 root_dir = os.getcwd()
-ck_tile_include = root_dir + "/include/ck_tile"
-ck_tile_example = root_dir + "/example/ck_tile"
+ck_tile_include = root_dir + "/projects/composablekernel/include/ck_tile"
+ck_tile_example = root_dir + "/projects/composablekernel/example/ck_tile"

 # Run for include
 os.chdir(ck_tile_include)
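
The validation style this patch extends in conv_algorithm_limits.hpp passes the configuration object itself as a non-type template parameter, so a C++20 concept can inspect its fields in a constant expression and reject a bad config before any kernel is instantiated. Below is a minimal self-contained sketch of that pattern; the LdsTransfer struct and the particular limits are illustrative stand-ins, not the builder's real types.

#include <cstddef>

// Hypothetical stand-in for the builder's LDS transfer descriptor.
struct LdsTransfer
{
    std::size_t src_vector_dim;
    std::size_t src_scalar_per_vector;
    std::size_t lds_dst_scalar_per_vector;
};

// The config object itself is the template argument, so every field is
// available in a constant expression and can be checked by the concept.
template <auto Value>
concept InputVectorTransferLimits = requires {
    requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
                 Value.lds_dst_scalar_per_vector > 0;
};

constexpr LdsTransfer good{.src_vector_dim = 2,
                           .src_scalar_per_vector = 4,
                           .lds_dst_scalar_per_vector = 4};
constexpr LdsTransfer bad{.src_vector_dim = 0,
                          .src_scalar_per_vector = 4,
                          .lds_dst_scalar_per_vector = 4};

static_assert(InputVectorTransferLimits<good>);
static_assert(!InputVectorTransferLimits<bad>); // rejected at compile time, no runtime cost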
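The new IsLDSVectorSizeValid helper is a pure arithmetic rule: N elements of a DataTypeSize-byte type occupy N * DataTypeSize * 8 bits, and that width must match one of the DS_READ/DS_WRITE widths listed in the patch (8, 16, 32, 64, 96, or 128 bits). A sketch under those assumptions, with is_any_of as a hypothetical stand-in for ck_tile::is_any_value_of:

#include <cstddef>

// True if value equals any of the candidates (fold over ==).
template <typename T, typename... Ts>
constexpr bool is_any_of(T value, Ts... candidates)
{
    return ((value == candidates) || ...);
}

template <std::size_t DataTypeSize, std::size_t N>
constexpr bool IsLDSVectorSizeValid()
{
    // Total width of the vector in bits.
    constexpr std::size_t bits = N * DataTypeSize * 8;
    return is_any_of(bits, 8u, 16u, 32u, 64u, 96u, 128u);
}

static_assert(IsLDSVectorSizeValid<4, 4>());  // 4 x float  -> 128 bits, DS_WRITE_B128
static_assert(IsLDSVectorSizeValid<2, 6>());  // 6 x half   ->  96 bits, DS_WRITE_B96
static_assert(!IsLDSVectorSizeValid<4, 6>()); // 6 x float  -> 192 bits, no DS instruction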
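The ExtractTensorDataType/ExtractTensorComputeType split introduced in this patch encodes a simple fallback rule: a per-tensor config may leave its type UNDEFINED, in which case the signature-wide type applies. A trimmed sketch of that consteval fallback, with the enum and struct simplified stand-ins for the builder's types:

// Trimmed stand-ins for the builder's DataType enum and tensor config.
enum class DataType { UNDEFINED_DATA_TYPE, FP16, FP32 };

struct TensorConfig { DataType data_type; };

template <TensorConfig Config, DataType SignatureDataType>
consteval DataType ExtractTensorDataType()
{
    using enum DataType;
    if constexpr(Config.data_type == UNDEFINED_DATA_TYPE)
    {
        return SignatureDataType; // fall back to the signature-wide default
    }
    else
    {
        return Config.data_type; // per-tensor override wins
    }
}

static_assert(ExtractTensorDataType<TensorConfig{DataType::UNDEFINED_DATA_TYPE},
                                    DataType::FP32>() == DataType::FP32);
static_assert(ExtractTensorDataType<TensorConfig{DataType::FP16},
                                    DataType::FP32>() == DataType::FP16);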