From beaa1aa47ced919db89e3901dbe0cc8236b9ba49 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Wed, 3 Dec 2025 22:46:22 -0800 Subject: [PATCH] [CK_TILE] Fix for Moving DataTypeTraits into a Common File (#3335) This PR fixes a mismatch caused when PR #3146 was merged out of sync with develop, which made its intended changes ineffective. This PR reapplies those changes to move DataTypeTraits into a common file to mitigate code duplication. Co-authored-by: Thomas Ning [ROCm/composable_kernel commit: 583fafc803a0ec9d0edc902fc6b9ecfdc42fb09b] --- .../gemm_multi_d_benchmark_single.cpp | 12 ++-- .../gemm_streamk_benchmark_single.cpp | 8 +-- .../ops/gemm_streamk/gemm_streamk_common.hpp | 59 +------------------ 3 files changed, 11 insertions(+), 68 deletions(-) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp index 41d2f736e1..25ac342f3e 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; - std::string dtype_d0 = DataTypeTraits::name; - std::string dtype_d1 = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; + std::string dtype_d0 = ck_tile::DataTypeTraits::name; + std::string dtype_d1 = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp index 13cadcd55a..5e88dc486a 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -83,10 +83,10 @@ void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp index 179aeb7307..15a3c91964 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -6,67 +6,10 @@ #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout)