|
|
|
|
@@ -215,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
|
|
|
|
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
|
|
|
|
const dim3 blocks = Kernel::BlockSize();
|
|
|
|
|
|
|
|
|
|
if(args.k_batch != 1)
|
|
|
|
|
{
|
|
|
|
|
throw std::runtime_error("split-k is not supported yet!");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Split-K validation is handled by Kernel::IsSupportedArgument
|
|
|
|
|
// Split-K is only supported for BQuantGrouped without preshuffle
|
|
|
|
|
if(!Kernel::IsSupportedArgument(kargs))
|
|
|
|
|
{
|
|
|
|
|
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
|
|
|
|
@@ -661,182 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else if(init_method == 3)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
|
|
|
|
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
|
|
|
|
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
|
|
|
|
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
|
|
|
|
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
|
|
|
|
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
|
|
|
|
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
|
|
|
|
|
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
|
|
|
|
|
|
|
|
|
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else if(init_method == 4)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
|
|
|
|
b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
|
|
|
|
a_m_k);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
}
|
|
|
|
|
ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*aq_tensor_ptr);
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
|
|
|
|
|
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
}
|
|
|
|
|
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*aq_tensor_ptr);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*aq_tensor_ptr);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else if(init_method == 5)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
|
|
|
|
b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
|
|
|
|
a_m_k);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
}
|
|
|
|
|
// Fill aquant such that column j has value j + 1 (1, 2, 3, ...)
|
|
|
|
|
for(ck_tile::index_t row = 0;
|
|
|
|
|
row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
|
|
|
|
|
++row)
|
|
|
|
|
{
|
|
|
|
|
for(ck_tile::index_t col = 0;
|
|
|
|
|
col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
|
|
|
|
|
++col)
|
|
|
|
|
{
|
|
|
|
|
(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
}
|
|
|
|
|
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
|
|
|
|
{
|
|
|
|
|
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
|
|
|
|
|
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
}
|
|
|
|
|
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*aq_tensor_ptr);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
|
|
|
|
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
|
|
|
|
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*aq_tensor_ptr);
|
|
|
|
|
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
|
|
|
|
*bq_tensor_ptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
a_m_k.SetZero();
|
|
|
|
|
|