mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Merge commit '15e81397a45d82e2c3032ac7b4e8a7ac0f66590a' into develop
This commit is contained in:
@@ -84,24 +84,51 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
// Epilogue selection: set to true for chainer-based, false for standard
|
||||
// CShuffleEpilogue
|
||||
constexpr bool UseChainerEpilogue = true;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
UseChainerEpilogue,
|
||||
// Chainer-based epilogue
|
||||
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
|
||||
ck_tile::CShuffleEpilogueChainProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>,
|
||||
ck_tile::DefaultScheduleTag>>,
|
||||
// Standard CShuffleEpilogue
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -172,31 +173,77 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
|
||||
// Epilogue selection: use chainer for RowCol/Tensor quant, standard for others
|
||||
// Toggle to switch between chainer-based and standard CShuffleEpilogue
|
||||
constexpr bool UseChainerEpilogue = true;
|
||||
|
||||
// Define the schedule tag based on quant mode
|
||||
using ScheduleTag =
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
ck_tile::RowColQuantScheduleTag,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::TensorQuantScheduleTag,
|
||||
ck_tile::DefaultScheduleTag>>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant),
|
||||
// Chainer-based epilogue for RowCol/Tensor quant modes
|
||||
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
|
||||
ck_tile::CShuffleEpilogueChainProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>,
|
||||
ScheduleTag>>,
|
||||
// Standard CShuffleEpilogue for other modes
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user