[CK_TILE] Fix for Moving DataTypeTraits into a Common File (#3335)

This PR fixes a mismatch caused when PR #3146 was merged out of sync with develop, which made its intended changes ineffective. This PR reapplies those changes to move DataTypeTraits into a common file to mitigate code duplication.

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
arai713
2025-12-03 22:46:22 -08:00
committed by GitHub
parent ffc3120f63
commit 583fafc803
3 changed files with 11 additions and 68 deletions

View File

@@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
std::string dtype_d0 = DataTypeTraits<D0DataType>::name;
std::string dtype_d1 = DataTypeTraits<D1DataType>::name;
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
std::string dtype_d0 = ck_tile::DataTypeTraits<D0DataType>::name;
std::string dtype_d1 = ck_tile::DataTypeTraits<D1DataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -83,10 +83,10 @@ void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -6,67 +6,10 @@
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// DataTypeTraits for all supported types
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)