diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 09d97bfb1e..65997a3308 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -311,9 +311,9 @@ struct CShuffleEpilogue using CWarpDstr = typename WG::CWarpDstr; using CWarpTensor = typename WG::CWarpTensor; using CWarpDstrEncoding = typename WG::CWarpDstrEncoding; - using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; + using SFC = space_filling_curve, + sequence<1, 0>, + sequence>; template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() @@ -667,7 +667,7 @@ struct CShuffleEpilogue const ScaleN& scale_n = {}) { constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - + //print(LdsTileDistr); auto lds_tile = make_static_distributed_tensor(LdsTileDistr); constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); @@ -736,8 +736,8 @@ struct CShuffleEpilogue static_assert(GetVectorSizeC() > 1, "VectorSizeC is not greater than 1!"); using TileEncodingPattern = tile_distribution_encoding_pattern_2d; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 96203b2cd2..a5dab8aa68 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -21,7 +21,7 @@ struct TileGemmTraits static constexpr bool kPadK = kPadK_; // TODO this can't be hardcoded here! Should be in policy! - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = 2; using AsLayout = AsLayout_; using BsLayout = BsLayout_; @@ -49,7 +49,7 @@ struct TileGemmUniversalTraits static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr bool kPadK = kPadK_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = 2; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using AsLayout = AsLayout_; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 88f25834b1..48e3d0651e 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -29,6 +29,7 @@ using NonPersistent = std::false_type; using I16 = ck_tile::number<16>; using I32 = ck_tile::number<32>; using I64 = ck_tile::number<64>; +using I128 = ck_tile::number<128>; using I256 = ck_tile::number<256>; // clang-format off @@ -78,10 +79,21 @@ using KernelTypesMemWmma = ::testing::Types< >; using KernelTypesCompV3 = ::testing::Types< +<<<<<<< HEAD std::tuple< Row, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3> //std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, //std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, //std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, +======= + std::tuple< Row, Row, Col, F16, F16, F32, F16, I128, I128, I64, I32, I32, I16, Intrawave, CompV3>//, + //std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + //std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + //std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + //std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + //std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + //std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + //std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, +>>>>>>> 0a9ceadca ([CK_TILE] working version) //std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, //std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, //std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index d3d7485954..2b2810c7b6 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -334,10 +334,45 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::index_t stride_C = ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{})); - ck_tile::HostTensor a_m_k( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{}))); + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({col, row}, {stride, 1_uz}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); + ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); + ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -348,9 +383,12 @@ class TestCkTileGemmPipeline : public ::testing::Test std::cout << "c_m_n_dev_result: "; c_m_n_dev_result.print_first_n(std::cout) << '\n'; - ck_tile::FillUniformDistributionIntegerValue{1, 2, 11939}(a_m_k); - ck_tile::FillUniformDistributionIntegerValue{1, 2, 11940}(b_k_n); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}(b_k_n); + //ck_tile::FillConstant{1}(a_m_k); + //ck_tile::FillConstant{2}(b_k_n); + // FillConstant ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());