mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] Move DataTypeTraits into a Common File (#3146)
This renames the typeToStr struct in the common utilities to DataTypeTraits and removes all duplication of DataTypeTraits across files in CK Tile. Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
This commit is contained in:
@@ -281,59 +281,36 @@ struct GemmQuantTypeConfig
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_q", "0", "Tensor AQ stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec",
|
||||
"fp8",
|
||||
"data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
|
||||
.insert("group_size",
|
||||
"1x1x128",
|
||||
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
@@ -321,15 +322,15 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
{
|
||||
std::cout << " StrideBQ =" << stride_BQ;
|
||||
}
|
||||
std::cout << " A_Type = " << DataTypeTraits<typename TypeConfig::ADataType>::name
|
||||
<< " AQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name
|
||||
<< " B_Type = " << DataTypeTraits<typename TypeConfig::BDataType>::name;
|
||||
std::cout << " A_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::ADataType>::name
|
||||
<< " AQ_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::QDataType>::name
|
||||
<< " B_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::BDataType>::name;
|
||||
if constexpr(!std::is_same_v<typename TypeConfig::QDataType, void>)
|
||||
{
|
||||
std::cout << " BQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name;
|
||||
std::cout << " BQ_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::QDataType>::name;
|
||||
}
|
||||
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = " << quant_type_to_string(QuantMode)
|
||||
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
|
||||
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
|
||||
Reference in New Issue
Block a user