mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Merge commit '775b96ea6a8bb0d82d635dc1a396c8d98091c832' into develop
This commit is contained in:
@@ -28,7 +28,8 @@ template <typename GemmConfig,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
typename CDataType,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
@@ -44,19 +45,20 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::RowColQuant;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
true>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false,
|
||||
false,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
true>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
|
||||
@@ -11,12 +11,6 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
#endif
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -66,7 +60,6 @@ struct GemmConfigBase
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -102,15 +95,6 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -119,7 +103,12 @@ auto create_args(int argc, char* argv[])
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
|
||||
.insert(
|
||||
"stride_As",
|
||||
"",
|
||||
"Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
|
||||
// can be set to zero if
|
||||
// Ms/Ns/Ks is not empty
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
|
||||
@@ -132,7 +121,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "tensor", "Choose tensor (default), or rowcol");
|
||||
;
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -145,13 +136,17 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
typename CDataType,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk = false);
|
||||
void* kargs_ptr);
|
||||
|
||||
@@ -43,6 +43,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
@@ -102,9 +103,10 @@ float invoke_gemm(int n_warmup,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(stream, group_count, kargs_ptr);
|
||||
CDataType,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
@@ -132,6 +134,7 @@ template <typename GemmConfig,
|
||||
typename BQDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
@@ -153,7 +156,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
};
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
@@ -180,7 +183,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
|
||||
ck_tile::index_t AQK, BQK;
|
||||
|
||||
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
|
||||
if(!valid_input_data(
|
||||
group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
|
||||
{
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
|
||||
@@ -242,25 +246,49 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1]. Only for NT
|
||||
BQK = N; // Column quantization: tensor shape [1, N]. Only for NT
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
stride_AQs[i] =
|
||||
ck_tile::get_default_stride(M, 1, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] = ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
@@ -324,7 +352,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout>(warmup, repeat, group_count, gemm_descs);
|
||||
CLayout,
|
||||
QuantMode>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
@@ -339,13 +368,33 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(a_m_k_tensors[i],
|
||||
aq_tensors[i],
|
||||
b_k_n_tensors[i],
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
ck_tile::reference_gemm_tensor_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(a_m_k_tensors[i],
|
||||
aq_tensors[i],
|
||||
b_k_n_tensors[i],
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
@@ -367,7 +416,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
@@ -388,7 +437,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
@@ -399,8 +449,9 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Row{}, Row{});
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
@@ -410,7 +461,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
@@ -421,7 +473,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -442,11 +495,28 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -143,7 +143,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
|
||||
@@ -159,7 +159,7 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
|
||||
Reference in New Issue
Block a user