diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index 354d236c20..74edddb6c9 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -10,6 +10,7 @@ #include #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/reduce.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "gemm_utils.hpp" @@ -589,9 +590,10 @@ float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf, << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C << " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes" << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name - << " B_Type=" << DataTypeTraits::name - << " C_Type=" << DataTypeTraits::name + << " C_Layout=" << CLayout::name + << " A_Type=" << ck_tile::DataTypeTraits::name + << " B_Type=" << ck_tile::DataTypeTraits::name + << " C_Type=" << ck_tile::DataTypeTraits::name << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index bdc37e5a94..b25aec101b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -401,63 +401,6 @@ struct GemmTypeConfig using CDataType = int32_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - template struct PipelineTypeTraits; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 204114d6bb..30cb3d3476 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "ck_tile/ops/common/utils.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -372,9 +373,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name - << " B_Type=" << DataTypeTraits::name - << " C_Type=" << DataTypeTraits::name + << " C_Layout=" << CLayout::name + << " A_Type=" << ck_tile::DataTypeTraits::name + << " B_Type=" << ck_tile::DataTypeTraits::name + << " C_Type=" << ck_tile::DataTypeTraits::name << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; @@ -442,18 +444,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, BDataType, CDataType, GemmConfig, - DataTypeTraits>(arg_parser.get_str("jsonfile"), - M, - N, - K, - stride_A, - stride_B, - stride_C, - persistent, - pass, - ave_time, - tflops, - gb_per_sec); + ck_tile::DataTypeTraits>(arg_parser.get_str("jsonfile"), + M, + N, + K, + stride_A, + stride_B, + stride_C, + persistent, + pass, + ave_time, + tflops, + gb_per_sec); } return pass; diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index f1509bfeef..677065c78d 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -6,21 +6,6 @@ #include "ck_tile/utility/json_dump.hpp" #include -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -145,7 +130,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(arg_parser.get_int("json") == 1) { - dump_reduce_json_results( + dump_reduce_json_results( arg_parser.get_str("jsonfile"), N, C, H, W, pass, ave_time, 0, gb_per_sec); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 47211bdbbc..ae1fa22bb0 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -136,38 +136,6 @@ struct GemmBasicTypeConfig using CDataType = ck_tile::half_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - template struct is_8bit_type : std::bool_constant || std::is_same_v> diff --git a/example/ck_tile/18_flatmm/moe_flatmm.hpp b/example/ck_tile/18_flatmm/moe_flatmm.hpp index b464aaa73a..47d969fadb 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.hpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.hpp @@ -134,38 +134,6 @@ struct GemmBasicTypeConfig using CDataType = ck_tile::half_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - template struct is_8bit_type : std::bool_constant || std::is_same_v> diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 238e3810f0..620b505820 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -254,27 +254,6 @@ struct ConvTypeConfig using OutDataType = ck_tile::bf16_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - template struct PipelineTypeTraits; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 95b0a73ede..bd9f93ce23 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -281,59 +281,36 @@ struct GemmQuantTypeConfig using CDataType = CDataType_; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits +auto create_args(int argc, char* argv[]) { - static constexpr const char* name = "fp32"; -}; + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("bq_layout", "C", "Bq tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", + "fp8", + "data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "1000", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") + .insert("rotating_count", "1000", "rotating count, defaults to 1") + .insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol") + .insert("group_size", + "1x1x128", + "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 2162141156..c9a57e7754 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -11,6 +11,7 @@ #include #include "ck_tile/core/config.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" @@ -321,15 +322,15 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, { std::cout << " StrideBQ =" << stride_BQ; } - std::cout << " A_Type = " << DataTypeTraits::name - << " AQ_Type = " << DataTypeTraits::name - << " B_Type = " << DataTypeTraits::name; + std::cout << " A_Type = " << ck_tile::DataTypeTraits::name + << " AQ_Type = " << ck_tile::DataTypeTraits::name + << " B_Type = " << ck_tile::DataTypeTraits::name; if constexpr(!std::is_same_v) { - std::cout << " BQ_Type = " << DataTypeTraits::name; + std::cout << " BQ_Type = " << ck_tile::DataTypeTraits::name; } - std::cout << " Acc_Type = " << DataTypeTraits::name - << " C_Type = " << DataTypeTraits::name + std::cout << " Acc_Type = " << ck_tile::DataTypeTraits::name + << " C_Type = " << ck_tile::DataTypeTraits::name << " QuantMode = " << quant_type_to_string(QuantMode) << " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : " diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp index 37aeec868a..dad31ec637 100644 --- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -54,39 +54,6 @@ struct StreamKGemmTypeConfig using CDataType = CDataType_; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index 041acff509..206b9c37fc 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -2,6 +2,8 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/common/utils.hpp" + template static constexpr inline auto is_row_major(Layout) { @@ -271,9 +273,10 @@ int run_gemm_example_with_layouts(int argc, std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name - << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name - << " B_Type=" << DataTypeTraits::name - << " C_Type=" << DataTypeTraits::name + << " C_Layout=" << CLayout::name + << " A_Type=" << ck_tile::DataTypeTraits::name + << " B_Type=" << ck_tile::DataTypeTraits::name + << " C_Type=" << ck_tile::DataTypeTraits::name << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " " << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 38da8eed1e..318a1a5860 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -11,15 +11,17 @@ namespace ck_tile { // clang-format off -template struct typeToStr; -template <> struct typeToStr { static constexpr const char * name = "fp32"; }; -template <> struct typeToStr { static constexpr const char * name = "fp16"; }; -template <> struct typeToStr { static constexpr const char * name = "bf16"; }; -template <> struct typeToStr { static constexpr const char * name = "fp8"; }; -template <> struct typeToStr { static constexpr const char * name = "bf8"; }; -template <> struct typeToStr { static constexpr const char * name = "int8"; }; -template <> struct typeToStr { static constexpr const char * name = "pk_int4"; }; -template <> struct typeToStr { static constexpr const char * name = "pk_fp4"; }; +template struct DataTypeTraits; +template <> struct DataTypeTraits { static constexpr const char * name = "fp32"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "fp64"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "int32"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "fp16"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "bf16"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "fp8"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "bf8"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; @@ -31,10 +33,10 @@ template <> struct memOpToStr { static constexpr con template std::string gemm_prec_str() { - std::string base_str = std::string(typeToStr::name); + std::string base_str = std::string(DataTypeTraits::name); if(!std::is_same_v) { - base_str += "_" + std::string(typeToStr::name); + base_str += "_" + std::string(DataTypeTraits::name); } return base_str; } diff --git a/tile_engine/ops/gemm/gemm_benchmark_single.cpp b/tile_engine/ops/gemm/gemm_benchmark_single.cpp index 6323c066a1..26f3a3928a 100644 --- a/tile_engine/ops/gemm/gemm_benchmark_single.cpp +++ b/tile_engine/ops/gemm/gemm_benchmark_single.cpp @@ -11,12 +11,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "gemm_profiler.hpp" #include "gemm_common.hpp" // The kernel header is included via the compile command line with -include flag // It defines SelectedKernel struct and KERNEL_NAME -// DataTypeTraits are now defined in gemm_common.hpp // Create argument parser inline auto create_args(int argc, char* argv[]) @@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[]) void benchmark_single(const ck_tile::ArgParser& arg_parser) { - // Use DataTypeTraits to get the actual type names from the generated header + // Use ck_tile::DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp index 899221547f..1fdc63b33b 100644 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -9,65 +9,6 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -//[TODO] This can be moved to commons -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp index 899221547f..1fdc63b33b 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp @@ -9,65 +9,6 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -//[TODO] This can be moved to commons -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout) diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 4fbb25f0c9..0d5de02750 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -11,12 +11,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "gemm_preshuffle_profiler.hpp" #include "gemm_preshuffle_common.hpp" // The kernel header is included via the compile command line with -include flag // It defines SelectedKernel struct and KERNEL_NAME -// DataTypeTraits are now defined in gemm_common.hpp // Create argument parser inline auto create_args(int argc, char* argv[]) @@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[]) void benchmark_single(const ck_tile::ArgParser& arg_parser) { - // Use DataTypeTraits to get the actual type names from the generated header + // Use ck_tile::DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index 1b2cfe3735..bb0b8090fa 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -9,65 +9,6 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -//[TODO] This can be moved to commons -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout)