mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Merge commit '331273b4747c9beebed5653e38f32ebede9f539b' into develop
This commit is contained in:
@@ -52,7 +52,7 @@ struct kernel
|
||||
template <class... Ts>
|
||||
auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
|
||||
{
|
||||
return [=](auto&&... xs) {
|
||||
return [=, this](auto&&... xs) {
|
||||
launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ template <typename GemmConfig,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
ck_tile::QuantType QuantMode>
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
@@ -48,8 +48,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false,
|
||||
false,
|
||||
false, // PreshuffleQuant
|
||||
false, // PreshuffleB
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -67,18 +67,29 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
using QuantGemmProblem = typename std::conditional<
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
128>, // QuantGroupSize
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
|
||||
using GemmPipeline =
|
||||
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -41,6 +42,14 @@ struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
@@ -77,24 +86,11 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -122,8 +118,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "tensor", "Choose tensor (default), or rowcol");
|
||||
;
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -43,8 +43,8 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
@@ -159,11 +159,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
const ck_tile::index_t QuantGroupSize = 128;
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
@@ -172,9 +173,11 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
|
||||
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
@@ -252,6 +255,15 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
@@ -289,6 +301,13 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
@@ -394,6 +413,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
@@ -441,42 +471,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantMode>(
|
||||
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
@@ -513,6 +507,41 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
}
|
||||
}
|
||||
if(data_type == "bf8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
|
||||
@@ -16,10 +16,17 @@ __device__ void llvm_amdgcn_s_wait_dscnt(short cnt) __asm("llvm.amdgcn.s.wait.ds
|
||||
__device__ void block_sync_lds()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#ifdef __gfx12__
|
||||
#if defined(__gfx12__)
|
||||
llvm_amdgcn_s_wait_dscnt(0);
|
||||
asm volatile("s_barrier_signal -1\n\t"
|
||||
"s_barrier_wait -1");
|
||||
#elif defined(__gfx11__)
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
__builtin_amdgcn_s_waitcnt(0xfc07);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
|
||||
@@ -185,14 +185,6 @@ struct GemmKernelMultiABD
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// Currently MultiABD kernel doesn't support F8 data type
|
||||
if(ck_tile::get_device_name() == "gfx950" &&
|
||||
(std::is_same<ck_tile::fp8_t, ADataType>::value ||
|
||||
std::is_same<ck_tile::fp8_t, BDataType>::value ||
|
||||
std::is_same<ck_tile::fp8_t, DDataType>::value))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
@@ -375,30 +375,48 @@ struct QuantGroupedGemmKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
else
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -472,6 +472,49 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
|
||||
///
|
||||
/// This operator is used by grouped GEMM kernels where pipeline parameters
|
||||
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
|
||||
/// at runtime, not on the host side during compilation. This is necessary
|
||||
/// because different GEMM problems in the group may have different K dimensions,
|
||||
/// requiring different pipeline configurations that cannot be determined at
|
||||
/// compile time.
|
||||
///
|
||||
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
|
||||
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
|
||||
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
|
||||
/// @param num_loop Number of main loop iterations (calculated on device)
|
||||
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
|
||||
/// @param tail_number Type of tail handling required (calculated on device)
|
||||
/// @param p_smem Pointer to shared memory
|
||||
/// @return Accumulated result tile in registers
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
|
||||
constexpr bool hot_loop = has_hot_loop_.value;
|
||||
constexpr auto tail_num = tail_number_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -20,20 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue enabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
|
||||
|
||||
// Currently MultiABD kernel doesn't support F8 data type
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -20,19 +20,17 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue disabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
|
||||
|
||||
// Currently MultiABD kernel doesn't support F8 data type
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
//std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -1,5 +1,95 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
|
||||
@@ -13,40 +13,9 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
template <typename E, typename A0, typename A1>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const
|
||||
{
|
||||
a = scale * (ck_tile::type_convert<float>(a0) + ck_tile::type_convert<float>(a1));
|
||||
}
|
||||
|
||||
float scale = 1.0;
|
||||
};
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
struct ElementWiseAddAdd
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) + ck_tile::type_convert<float>(d0) +
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
using AddScale = ck_tile::element_wise::AddScale;
|
||||
using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd;
|
||||
using MultiplyMultiply = ck_tile::element_wise::MultiDMultiply;
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
|
||||
@@ -18,6 +18,7 @@ using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
@@ -31,16 +32,16 @@ using KernelTypes = ::testing::Types<
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -26,3 +26,32 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
|
||||
// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop
|
||||
// Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128)
|
||||
TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
|
||||
{
|
||||
const int group_count = 2;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
std::vector<int> stride_AQs;
|
||||
std::vector<int> stride_BQs;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256);
|
||||
Ns.push_back(256);
|
||||
Ks.push_back(256);
|
||||
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
|
||||
@@ -107,7 +107,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
constexpr bool transpose_c = false;
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using QuantGemmProblem =
|
||||
using QuantGemmProblem = typename std::conditional<
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
128>, // QuantGroupSize
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -116,9 +124,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
scheduler>>::type;
|
||||
|
||||
using GemmPipeline = typename std::conditional<
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -244,6 +256,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / 128; // Group quantization: BQK = K / GroupSize
|
||||
if(K % 128 != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{}));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{}));
|
||||
@@ -258,7 +279,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
else if constexpr(QuantType == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] = 0; // No A quantization
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
|
||||
}
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
@@ -285,6 +312,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
|
||||
1, 1, stride_BQs[i], is_row_major(BQLayout()))));
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(
|
||||
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
|
||||
0, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
|
||||
bq_tensors.push_back(
|
||||
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
|
||||
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
|
||||
}
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
@@ -373,7 +409,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
|
||||
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr);
|
||||
}
|
||||
@@ -420,6 +455,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
128,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
|
||||
Reference in New Issue
Block a user