mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
fixing CI failures for grouped quant gemm
This commit is contained in:
@@ -29,6 +29,7 @@ template <typename GemmConfig,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
@@ -75,7 +76,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
128>, // QuantGroupSize
|
||||
QuantGroupSize>, // QuantGroupSize
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -43,6 +43,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
@@ -104,6 +105,7 @@ float invoke_gemm(int n_warmup,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
|
||||
std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
|
||||
@@ -134,6 +136,7 @@ template <typename GemmConfig,
|
||||
typename BQDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -159,13 +162,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
const int init_method = arg_parser.get_int("init");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
const ck_tile::index_t QuantGroupSize = 128;
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
const int init_method = arg_parser.get_int("init");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
@@ -259,9 +261,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize != 0)
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
}
|
||||
@@ -400,6 +402,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
QuantMode>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
@@ -481,12 +484,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
using AQDataType = typename Types::AccDataType;
|
||||
using BQDataType = typename Types::AccDataType;
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
using AQDataType = typename Types::AccDataType;
|
||||
using BQDataType = typename Types::AccDataType;
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
@@ -496,6 +501,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
|
||||
@@ -390,6 +390,7 @@ struct QuantGroupedGemmKernel
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
kargs.N,
|
||||
num_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
|
||||
Reference in New Issue
Block a user