From 46f27e2ab0c43581de97abd3a066eeebdb21fa24 Mon Sep 17 00:00:00 2001 From: Aleksander Dudek Date: Wed, 22 Oct 2025 10:22:47 -0500 Subject: [PATCH] [CK_TILE] working version and tests --- .../gemm/test_gemm_pipeline_kernel_types.hpp | 15 +++++++++++++++ test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 17 ++++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 48e3d0651e..34761831a5 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -79,6 +79,7 @@ using KernelTypesMemWmma = ::testing::Types< >; using KernelTypesCompV3 = ::testing::Types< +<<<<<<< HEAD <<<<<<< HEAD std::tuple< Row, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3> //std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, @@ -130,6 +131,20 @@ using KernelTypesCompV3 = ::testing::Types< //std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, //std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, //std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3> +======= + std::tuple< Row, Row, Col, F16, F16, F32, F16, I128, I128, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3> +>>>>>>> 12c48382b ([CK_TILE] working version and tests) >; using KernelTypesCompV3Wmma = ::testing::Types< diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 2b2810c7b6..4774c2c55a 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -343,6 +343,20 @@ class TestCkTileGemmPipeline : public ::testing::Test return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); } else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_host_tensor_descriptor_out = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else { return ck_tile::HostTensorDescriptor({col, row}, {stride, 1_uz}); } @@ -386,9 +400,6 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}(a_m_k); ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}(b_k_n); - //ck_tile::FillConstant{1}(a_m_k); - //ck_tile::FillConstant{2}(b_k_n); - // FillConstant ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());