diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp index 41d2f736e1..25ac342f3e 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; - std::string dtype_d0 = DataTypeTraits::name; - std::string dtype_d1 = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; + std::string dtype_d0 = ck_tile::DataTypeTraits::name; + std::string dtype_d1 = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp index 13cadcd55a..5e88dc486a 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -83,10 +83,10 @@ void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp index 179aeb7307..15a3c91964 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -6,67 +6,10 @@ #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout)