Merge commit '15e81397a45d82e2c3032ac7b4e8a7ac0f66590a' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-18 09:16:13 +00:00
parent a84d4d52bd
commit 792c3eb377
9 changed files with 1244 additions and 42 deletions

View File

@@ -84,24 +84,51 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
// Epilogue selection: set to true for chainer-based, false for standard
// CShuffleEpilogue
constexpr bool UseChainerEpilogue = true;
using GemmEpilogue = std::conditional_t<
UseChainerEpilogue,
// Chainer-based epilogue
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
ck_tile::CShuffleEpilogueChainProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>,
ck_tile::DefaultScheduleTag>>,
// Standard CShuffleEpilogue
ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>>;
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);

View File

@@ -16,6 +16,7 @@
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "gemm_utils.hpp"
template <typename GemmConfig,
@@ -172,31 +173,77 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
}
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename TypeConfig::ADataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
// Epilogue selection: use chainer for RowCol/Tensor quant, standard for others
// Toggle to switch between chainer-based and standard CShuffleEpilogue
constexpr bool UseChainerEpilogue = true;
// Define the schedule tag based on quant mode
using ScheduleTag =
std::conditional_t<QuantMode == ck_tile::QuantType::RowColQuant,
ck_tile::RowColQuantScheduleTag,
std::conditional_t<QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::TensorQuantScheduleTag,
ck_tile::DefaultScheduleTag>>;
using GemmEpilogue = std::conditional_t<
UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant),
// Chainer-based epilogue for RowCol/Tensor quant modes
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
ck_tile::CShuffleEpilogueChainProblem<
typename TypeConfig::ADataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType,
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledPermuteN>,
ScheduleTag>>,
// Standard CShuffleEpilogue for other modes
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename TypeConfig::ADataType,
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledPermuteN>>;
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType,
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledPermuteN>>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;