[CK_TILE] Move DataTypeTraits into a Common File (#3146)

This renames the typeToStr struct in the common utilities to DataTypeTraits and removes all duplication of DataTypeTraits across files in CK Tile.

Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
This commit is contained in:
arai713
2025-11-27 09:09:54 -08:00
committed by GitHub
parent 678298d4c7
commit 24d88d2472
17 changed files with 92 additions and 472 deletions

View File

@@ -10,6 +10,7 @@
#include <tuple>
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "gemm_utils.hpp"
@@ -589,9 +590,10 @@ float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes"
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " C_Layout=" << CLayout::name
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;

View File

@@ -401,63 +401,6 @@ struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
using CDataType = int32_t;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "ck_tile/ops/common/utils.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -372,9 +373,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " C_Layout=" << CLayout::name
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -442,18 +444,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
BDataType,
CDataType,
GemmConfig,
DataTypeTraits>(arg_parser.get_str("jsonfile"),
M,
N,
K,
stride_A,
stride_B,
stride_C,
persistent,
pass,
ave_time,
tflops,
gb_per_sec);
ck_tile::DataTypeTraits>(arg_parser.get_str("jsonfile"),
M,
N,
K,
stride_A,
stride_B,
stride_C,
persistent,
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass;

View File

@@ -6,21 +6,6 @@
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -145,7 +130,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(arg_parser.get_int("json") == 1)
{
dump_reduce_json_results<DataType, DataTypeTraits>(
dump_reduce_json_results<DataType, ck_tile::DataTypeTraits>(
arg_parser.get_str("jsonfile"), N, C, H, W, pass, ave_time, 0, gb_per_sec);
}

View File

@@ -136,38 +136,6 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <typename T>
struct is_8bit_type
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>

View File

@@ -134,38 +134,6 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <typename T>
struct is_8bit_type
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>

View File

@@ -254,27 +254,6 @@ struct ConvTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
using OutDataType = ck_tile::bf16_t;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;

View File

@@ -281,59 +281,36 @@ struct GemmQuantTypeConfig
using CDataType = CDataType_;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
auto create_args(int argc, char* argv[])
{
static constexpr const char* name = "fp32";
};
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec",
"fp8",
"data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1")
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
.insert("group_size",
"1x1x128",
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -11,6 +11,7 @@
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
@@ -321,15 +322,15 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
{
std::cout << " StrideBQ =" << stride_BQ;
}
std::cout << " A_Type = " << DataTypeTraits<typename TypeConfig::ADataType>::name
<< " AQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name
<< " B_Type = " << DataTypeTraits<typename TypeConfig::BDataType>::name;
std::cout << " A_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::ADataType>::name
<< " AQ_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::QDataType>::name
<< " B_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::BDataType>::name;
if constexpr(!std::is_same_v<typename TypeConfig::QDataType, void>)
{
std::cout << " BQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name;
std::cout << " BQ_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::QDataType>::name;
}
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "

View File

@@ -54,39 +54,6 @@ struct StreamKGemmTypeConfig
using CDataType = CDataType_;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -2,6 +2,8 @@
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/common/utils.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout)
{
@@ -271,9 +273,10 @@ int run_gemm_example_with_layouts(int argc,
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " C_Layout=" << CLayout::name
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

View File

@@ -11,15 +11,17 @@
namespace ck_tile {
// clang-format off
template <typename T> struct typeToStr;
template <> struct typeToStr<float> { static constexpr const char * name = "fp32"; };
template <> struct typeToStr<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct typeToStr<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
template <> struct typeToStr<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
template <> struct typeToStr<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
template <typename T> struct DataTypeTraits;
template <> struct DataTypeTraits<float> { static constexpr const char * name = "fp32"; };
template <> struct DataTypeTraits<double> { static constexpr const char * name = "fp64"; };
template <> struct DataTypeTraits<int32_t> { static constexpr const char * name = "int32"; };
template <> struct DataTypeTraits<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct DataTypeTraits<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct DataTypeTraits<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct DataTypeTraits<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct DataTypeTraits<int8_t> { static constexpr const char * name = "int8"; };
template <> struct DataTypeTraits<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
template <> struct DataTypeTraits<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
template <memory_operation_enum MemOp> struct memOpToStr;
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };
@@ -31,10 +33,10 @@ template <> struct memOpToStr<memory_operation_enum::add> { static constexpr con
template <typename ADataType_, typename BDataType_>
std::string gemm_prec_str()
{
std::string base_str = std::string(typeToStr<ADataType_>::name);
std::string base_str = std::string(DataTypeTraits<ADataType_>::name);
if(!std::is_same_v<ADataType_, BDataType_>)
{
base_str += "_" + std::string(typeToStr<BDataType_>::name);
base_str += "_" + std::string(DataTypeTraits<BDataType_>::name);
}
return base_str;
}

View File

@@ -11,12 +11,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "gemm_profiler.hpp"
#include "gemm_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_common.hpp
// Create argument parser
inline auto create_args(int argc, char* argv[])
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -9,65 +9,6 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)

View File

@@ -9,65 +9,6 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)

View File

@@ -11,12 +11,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "gemm_preshuffle_profiler.hpp"
#include "gemm_preshuffle_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_common.hpp
// Create argument parser
inline auto create_args(int argc, char* argv[])
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -9,65 +9,6 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)