Merge branch 'develop' into vpietila/add-fwd-conv-v3-instances-for-unit-group-size

This commit is contained in:
Ville Pietilä
2026-01-29 10:38:35 +02:00
committed by GitHub
56 changed files with 1636 additions and 648 deletions

View File

@@ -750,9 +750,21 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
else if(qscale.type == quant_scale_enum::blockscale)
{
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float qkv_max = 3.f;
float max_descale_q = qkv_max / q_dtype_max;
float max_descale_k = qkv_max / k_dtype_max;
float max_descale_v = qkv_max / v_dtype_max;
ck_tile::FillUniformDistribution<float>{max_descale_q * 0.8f, max_descale_q, next_seed()}(
q_descale_host);
ck_tile::FillUniformDistribution<float>{max_descale_k * 0.8f, max_descale_k, next_seed()}(
k_descale_host);
ck_tile::FillUniformDistribution<float>{max_descale_v * 0.8f, max_descale_v, next_seed()}(
v_descale_host);
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);

View File

@@ -59,7 +59,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
@@ -202,7 +203,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,

View File

@@ -44,7 +44,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
@@ -210,7 +211,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,

View File

@@ -134,5 +134,35 @@ static auto _ = []() {
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
return 0;
}();

View File

@@ -80,7 +80,8 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr bool PreshuffleQuant = false;
static constexpr bool APreshuffleQuant = false;
static constexpr bool BPreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false;
@@ -157,7 +158,8 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleQuant = true;
static constexpr bool APreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>
@@ -187,7 +189,7 @@ template <typename PrecType>
struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode
: public GemmConfigPreshuffleB_BQuant_Decode<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>
@@ -218,7 +220,7 @@ template <typename PrecType>
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>
@@ -272,7 +274,7 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>

View File

@@ -33,7 +33,8 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
constexpr bool transpose_c =
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
@@ -50,14 +51,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::PreshuffleQuant,
GemmConfig::APreshuffleQuant,
GemmConfig::BPreshuffleQuant,
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout, // for AQLayout
BQLayout, // for BQLayout
AQLayout,
BQLayout,
transpose_c,
GemmConfig::DoubleSmemBuffer>;
@@ -73,7 +75,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
@@ -146,7 +148,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
has_hot_loop_v,
tail_number_v>>>>;
using AQuantPipeline =
std::conditional_t<GemmConfig::PreshuffleQuant,
std::conditional_t<GemmConfig::APreshuffleQuant,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
@@ -390,8 +392,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< " APreshuffleQuant = " << (GemmConfig::APreshuffleQuant ? "true" : "false")
<< " : "
<< " BPreshuffleQuant = " << (GemmConfig::BPreshuffleQuant ? "true" : "false")
<< " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
@@ -536,21 +540,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
// Create BQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
}
std::mt19937 gen(42);
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
@@ -870,7 +866,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleQuant)
if constexpr(GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK);
@@ -929,7 +925,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
if constexpr(GemmConfig::PreshuffleQuant)
if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host = ck_tile::shuffle_bq(
&bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK);
@@ -940,7 +936,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
}
}
else if constexpr(GemmConfig::PreshuffleQuant)
else if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK);
@@ -1121,7 +1117,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
{
if(a_layout == "R" && b_layout == "R")
{
@@ -1142,7 +1138,8 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
}
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
!GemmConfig::APreshuffleQuant)
{
if(a_layout == "C" && b_layout == "C")
{

View File

@@ -35,7 +35,7 @@ template <typename T>
concept BwdXdlV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesBlockGemm<T>;
SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
template <typename T>
concept BwdWmmaAlgorithmBase =

View File

@@ -53,7 +53,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB>
typename ComputeTypeB,
bool DirectLoad,
index_t NumGroupsToMerge>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
} // namespace ck::tensor_operation::device
@@ -109,7 +111,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA_,
typename ComputeTypeB_>
typename ComputeTypeB_,
bool DirectLoad,
index_t NumGroupsToMerge>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
NDimSpatial,
InLayout_,
@@ -153,7 +157,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA_,
ComputeTypeB_>>
ComputeTypeB_,
DirectLoad,
NumGroupsToMerge>>
{
/// @brief Tag type identifying this device kernel variant
@@ -241,6 +247,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
static constexpr bool kDirectLoad = DirectLoad;
static constexpr index_t kNumGroupsToMerge = NumGroupsToMerge;
// Static member function to generate instance string
static std::string instance_string()
{
@@ -302,6 +311,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41.
oss << "," << detail::type_name<ComputeTypeA>(); // 42.
oss << "," << detail::type_name<ComputeTypeB>(); // 43.
oss << "," << kDirectLoad; // 44.
oss << "," << kNumGroupsToMerge; // 45.
oss << ">";
return oss.str();

View File

@@ -32,7 +32,8 @@ constexpr auto ALGORITHM =
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;

View File

@@ -632,7 +632,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_>;
BlockGemm_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
ConvAlgorithmTemplate<ThreadBlock_,

View File

@@ -69,6 +69,8 @@ std::string expected_str =
",v1" // BlkGemmPipelineVer
",fp16" // ComputeTypeA
",fp16" // ComputeTypeB
",0" // DirectLoad
",1" // NumGroupsToMerge
">";
// Test describe() through base class pointer for XDL V3 variant

View File

@@ -30,7 +30,8 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false>
bool TransposeC = false,
bool LdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlops_pipeline_base
{
static constexpr auto I0 = Number<0>{};
@@ -385,7 +386,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
LdsScalarLoadToVgpr ? 1 : A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
@@ -395,7 +396,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
LdsScalarLoadToVgpr ? 1 : B_K1,
B_K1>;
AThreadCopy a_thread_copy_;

View File

@@ -32,9 +32,15 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool DirectLoad = false>
bool DirectLoad = false,
bool LdsScalarLoadToVgpr = false>
constexpr auto BlockGemmPipeline_Selector()
{
// Supported for Direct Load and V1
if constexpr(LdsScalarLoadToVgpr)
{
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
}
if constexpr(DirectLoad)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
@@ -58,7 +64,8 @@ constexpr auto BlockGemmPipeline_Selector()
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
KPack,
LdsScalarLoadToVgpr>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{

View File

@@ -758,7 +758,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
index_t KPacks,
bool LdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
{
};
@@ -781,9 +782,9 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
index_t KPack,
// ,bool TransposeC //disable transposec right now...
>
bool LdsScalarLoadToVgpr>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -803,7 +804,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
LdsScalarLoadToVgpr>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -822,7 +824,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -843,7 +847,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>;
using Base::I0;
using Base::KRepeat;
using Base::xdlops_gemm;

View File

@@ -140,10 +140,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");
static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");

View File

@@ -82,23 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
#if defined(__gfx950__)
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
#endif
}
else
{
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
}
#else
ignore = karg;
@@ -236,7 +261,9 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false,
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
@@ -287,7 +314,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
NPerBlock,
K1Number,
K0PerBlock / K1Number,
1 /*NumGroupsToMerge*/,
NumGroupsToMerge,
ConvBackwardWeightSpecialization>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
@@ -371,6 +398,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)
: ABlockTransferSrcScalarPerVector;
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned =
BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
? 4 / sizeof(BDataType)
: BBlockTransferSrcScalarPerVector;
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
@@ -399,7 +436,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
@@ -407,7 +444,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
@@ -418,7 +455,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
DirectLoad>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -653,15 +691,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
// A/B/C Batch Stride (multiply by NumGroupsToMerge for group merging)
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
std::multiplies<>{}) *
NumGroupsToMerge;
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -743,7 +782,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_);
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);
float ave_time = 0;
@@ -1367,6 +1406,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
#endif
// check device
if constexpr(DirectLoad)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
// Check that NumGroupsToMerge divides Conv_G evenly
if constexpr(NumGroupsToMerge > 1)
{
if(arg.Conv_G_ % NumGroupsToMerge != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Unsupported! Conv_G_ % NumGroupsToMerge != 0: Conv_G_="
<< arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge
<< std::endl;
}
return false;
}
}
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
@@ -1617,8 +1680,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"
<< "<"
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
if constexpr(DirectLoad) {
str << "_DirectLoad";
}
str << "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "

View File

@@ -567,6 +567,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
// Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)

View File

@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp"
namespace ck {
@@ -61,7 +62,8 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false>
struct GridwiseGemm_xdl_cshuffle_conv_v3
: public GridwiseGemm_xdl_cshuffle_base<
ALayout,
@@ -109,6 +111,10 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
ComputeTypeB,
false> // ForceNaiveLayout
{
static_assert((is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>) ||
!DirectLoad);
using Base = GridwiseGemm_xdl_cshuffle_base<
ALayout,
BLayout,
@@ -164,6 +170,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using Base::I2;
using ThisThreadBlock = typename Base::ThisThreadBlock;
static constexpr bool DirectLoadEnabled = DirectLoad;
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
@@ -353,7 +361,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
template <typename DeviceArch>
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch)
{
if constexpr(is_same_v<DeviceArch, gfx950_t>)
if constexpr(DirectLoad)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
}
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
{
// Force use padded layout on gfx950 to reduce bank conflicts
constexpr index_t ABlockLdsExtraM = 1;
@@ -370,7 +384,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
template <typename DeviceArch>
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch)
{
if constexpr(is_same_v<DeviceArch, gfx950_t>)
if constexpr(DirectLoad)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
}
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
{
constexpr index_t BBlockLdsExtraN = 1;
return make_naive_tensor_descriptor(
@@ -385,31 +405,36 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType)
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
// Disable vector load from lds to vgpr for direct load (backward weight store with continous M
// or N dimension)
static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
DirectLoad,
LdsScalarLoadToVgpr>())>;
template <typename DeviceArch>
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
@@ -539,67 +564,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_copy();
auto b_blockwise_copy = get_b_blockwise_copy();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -722,67 +799,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_copy();
auto b_blockwise_copy = get_b_blockwise_copy();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(

View File

@@ -26,17 +26,26 @@ namespace ck_tile {
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions.
* @tparam StaticPageIndexArray_ Array type holding page indices for scatter/gather.
* @tparam StaticValidArray_ Array type holding validity flags (nullptr_t if unused).
* @tparam HsGatherDim H-space dimension index used for gather lookup (default: 0).
* @tparam NumCoord Number of pre-computed coordinates for pipelining (default: 1).
* @tparam YsGatherDims Sequence of Y-space dimension indices used for page lookup.
* For single dimension: sequence<0> (default).
* For multiple dimensions: sequence<dim0, dim1, ...> where
* the combined index is computed as:
* idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] *
* len[dim1] + ...
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
index_t HsGatherDim = 0,
index_t NumCoord = 1,
typename YsGatherDims = sequence<0>>
struct tile_scatter_gather
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
@@ -77,6 +86,75 @@ struct tile_scatter_gather
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
/**
* @brief Check if a given Y-space dimension index is a gather dimension.
*
* Gather dimensions are those specified in YsGatherDims template parameter.
* When computing forward_step_scatter, gather dimensions are set to 0
* because page offset lookup handles address calculation for these dimensions.
*
* @param i Y-space dimension index to check
* @return true if dimension i is in YsGatherDims, false otherwise
*/
CK_TILE_DEVICE static constexpr bool is_gather_dim(index_t i)
{
return sequence_any_of(YsGatherDims{}, [i](auto k) { return i == k; });
}
/**
* @brief Compute the linearized gather index from Y-space indices for page lookup.
*
* This function converts multi-dimensional Y-space indices (specified by YsGatherDims)
* into a single linearized index used to look up the page offset in page_idx_ array.
*
* For single gather dimension (YsGatherDims::size() == 1):
* Simply returns idx_ys_start[YsGatherDims::at(0)]
*
* For multiple gather dimensions (e.g., YsGatherDims = sequence<0, 2>):
* Computes: idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] * len[dim1] + ...
* This is row-major linearization where earlier dimensions are inner (faster-varying).
*
* @tparam YsIndex Type of the Y-space index tuple/array
* @param idx_ys_start Current Y-space indices from space-filling curve iteration
* @return Linearized index for page_idx_ array lookup
*/
template <typename YsIndex>
CK_TILE_DEVICE static constexpr auto get_gather_index(const YsIndex& idx_ys_start)
{
// TODO: Consider making ys_lengths_ part of public API or adding accessor
static_assert(sizeof(TileDstr::DstrEncode::detail::ys_lengths_) > 0,
"Relies on internal detail::ys_lengths_");
constexpr index_t num_gather_dims = YsGatherDims::size();
if constexpr(num_gather_dims == 1)
{
return idx_ys_start[number<YsGatherDims::at(0)>{}];
}
else
{
// Recursive lambda to compute index as a compile-time number
// Uses row-major linearization: idx[0] + idx[1] * len[0] + idx[2] * len[0] * len[1] +
// ...
auto recurse = [&](auto self, auto i_constant) {
constexpr index_t i = decltype(i_constant)::value;
constexpr index_t dim = YsGatherDims::at(i);
auto current_val = idx_ys_start[number<dim>{}];
if constexpr(i + 1 < num_gather_dims)
{
constexpr index_t len = TileDstr::DstrEncode::detail::ys_lengths_[dim];
return current_val + self(self, number<i + 1>{}) * number<len>{};
}
else
{
return current_val;
}
};
return recurse(recurse, number<0>{});
}
}
struct load_store_traits
{
private:
@@ -375,7 +453,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
@@ -427,7 +505,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -485,7 +563,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// merge page_offset into bottom_coord
@@ -513,7 +591,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -598,7 +676,7 @@ struct tile_scatter_gather
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
@@ -624,7 +702,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -718,7 +796,7 @@ struct tile_scatter_gather
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
@@ -748,7 +826,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -791,7 +869,7 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from distributed tensor
@@ -837,7 +915,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -874,11 +952,11 @@ struct tile_scatter_gather
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
constexpr auto idx_gather = get_gather_index(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// get_gather_index(idx_ys_start)+0, idx_ys_start[number<1>{}]+0);
// read from distributed tensor
// vector_type_t vec;
@@ -928,7 +1006,7 @@ struct tile_scatter_gather
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
@@ -1076,6 +1154,53 @@ struct tile_scatter_gather
};
// TODO: use strategy
/**
* @brief Factory function to create tile_scatter_gather with multi-dimensional gather support.
*
* This overload accepts a sequence<YsGatherDims...> to specify multiple Y-space dimensions
* for page lookup. Use this when the tile distribution decomposes the paged dimension
* into multiple Y-space dimensions (e.g., VECTORIZED_LAYOUT V tensor with K decomposition
* {K2, K0, K1} where both Y0 and Y2 contribute to page index).
*
* @tparam HsGatherDim H-space dimension for gather
* @tparam NumCoord Number of pre-computed coordinates
* @tparam YsGatherDims Parameter pack specifying which Y-dimensions are used for page lookup
*
* @param tensor_view The underlying tensor view for device memory access
* @param window_lengths Static window sizes for each dimension
* @param origin Window origin coordinates on the bottom tensor
* @param tile_distribution Thread-to-tile mapping distribution
* @param page_idx Array of page offsets (in bytes) for scatter/gather
*/
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
index_t HsGatherDim,
index_t NumCoord,
index_t... YsGatherDims>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
number<HsGatherDim>,
number<NumCoord>,
sequence<YsGatherDims...>)
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
HsGatherDim,
NumCoord,
sequence<YsGatherDims...>>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
// Legacy overload (compatible with original API)
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -1087,7 +1212,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx, // perbytes
const StaticPageIndexArray_& page_idx,
number<HsGatherDim> = {},
number<NumCoord> = {})
{
@@ -1097,7 +1222,8 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
HsGatherDim,
NumCoord>{
NumCoord,
sequence<0>>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}

View File

@@ -533,32 +533,170 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
const auto VPageIndexDim = I1;
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
0,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
// V tensor K-dimension decomposition for page index computation
// ============================================================
// The K dimension (seqlen_k) in V distribution is decomposed into multiple sub-dimensions.
// This decomposition determines how threads iterate over the K dimension and how page
// indices are computed for paged KV cache.
//
// The decomposition pattern differs by memory layout:
//
// VECTORIZED_LAYOUT (ColumnMajor, custom distribution):
// 3D decomposition: K = K2 × K0 × K1
// - K2 (V_KIterOuter): Outer iteration count
// - K0 (V_KLanes): Lanes for K dimension (matches GEMM kABKLane)
// - K1 (V_KIterInner): Vector load size (matches GEMM kKPerThread)
// - hs_lengthss_[I1] = {K2, K0, K1}, size = 3 (or {K0, K1} size = 2 if no outer iter)
//
// LINEAR_LAYOUT ColumnMajor (base class distribution):
// 2D decomposition: K = K0 × K1
// - K0: Lanes for K dimension (may not match GEMM kABKLane)
// - K1: Vector load size
// - hs_lengthss_[I1] = {K0, K1}, size = 2
//
// LINEAR_LAYOUT RowMajor (base class distribution):
// 4D decomposition: K = K0 × K1 × K2 × K3 (uses shuffle_tile for GEMM alignment)
// 3D decomposition: K = K0 × K1 × K2 (fallback case)
// - Page lookup uses Y-space's last dimension only (inner iteration)
//
// V_PageIdxRepeat = total number of page lookups per thread = V_KIterOuter × V_KIterInner
constexpr index_t V_KIterInner = VDstrEncode::hs_lengthss_[I1].back();
// Compute V_KIterOuter and V_KLanes based on memory layout and K decomposition
constexpr index_t V_KIterOuter = [] {
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// VECTORIZED_LAYOUT: 3D decomposition {K2, K0, K1} when outer iteration is needed
if constexpr(VDstrEncode::hs_lengthss_[I1].size() == 3)
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
else
return index_t{1};
}
else
{
// LINEAR_LAYOUT: No outer iteration for page lookup
// RowMajor uses shuffle_tile, ColumnMajor has simple 2D decomposition
// Both cases use single-dimension Y-space page lookup
return index_t{1};
}
}();
constexpr index_t V_KLanes = [] {
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// VECTORIZED_LAYOUT: K0 is the lanes dimension
if constexpr(V_KIterOuter > 1)
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I1]);
else
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
}
else
{
// LINEAR_LAYOUT: First dimension is K0 (lanes)
return static_cast<index_t>(VDstrEncode::hs_lengthss_[I1][I0]);
}
}();
// This affects page offset computation - need to track offsets for each (k2, k1)
// combination
constexpr index_t V_PageIdxRepeat = V_KIterInner * V_KIterOuter;
// VPageIndexYDims: Y-space dimension indices that participate in page index computation
// ================================================================================
// In tile_scatter_gather, the gather index is computed from Y-space coordinates.
// This sequence specifies which Y dimensions should be linearized to form the page lookup
// index.
//
// VECTORIZED_LAYOUT with outer iteration: sequence<Y_K1, Y_K2>
// - Both K1 and K2 are in Y-space (thread iteration dimensions)
// - gather_index = y_k1 + y_k2 * len(Y_K1) (linearized 2D -> 1D)
//
// VECTORIZED_LAYOUT without outer iteration / LINEAR_LAYOUT: sequence<Y_K1>
// - Only the innermost K dimension is used for page lookup (single dimension)
//
constexpr auto VPageIndexYDims = []() {
// K1Minor is always the last element index in hs_lengthss_[I1]
constexpr index_t K1Minor = VDstrEncode::hs_lengthss_[I1].size() - 1;
constexpr index_t Y_K1 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][K1Minor];
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT &&
V_KIterOuter > 1)
{
// VECTORIZED_LAYOUT with outer iteration: need 2D page lookup
constexpr index_t Y_K2 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][I0];
return sequence<Y_K1, Y_K2>{};
}
else
{
// LINEAR_LAYOUT or VECTORIZED_LAYOUT without outer iteration: 1D page lookup
return sequence<Y_K1>{};
}
}();
static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY,
"V page-index Y dim must be valid");
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
auto update_v_offsets = [&](auto k_loop_start) {
constexpr index_t kLoopStart = decltype(k_loop_start)::value;
// For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice
// The global K offset for (k2, k1) is: kLoopStart + k2 * (K0 * K1) + k1
// We iterate K2 outer, K1 inner, and merge into 1D v_offsets array
if constexpr(V_KIterOuter > 1)
{
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
statically_indexed_array<index_t, V_KIterInner> v_offsets_k2;
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart + k2.value * V_KLanes * V_KIterInner,
V_KIterInner,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k);
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
v_offsets[idx] = v_offsets_k2[k1];
});
});
}
else
{
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart,
V_KIterInner,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
}
};
update_v_offsets(number<0>{});
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
v_dist,
v_offsets,
VPageIndexDim);
number<1>{}, // HsGatherDim
number<1>{}, // NumCoord
VPageIndexYDims);
// prefetch K tile
async_load_tile_raw(
@@ -625,18 +763,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(1);
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
update_v_offsets(number<kK1>{});
v_dram_window.update_page_idx(v_offsets);
const auto p = [&]() {
@@ -766,7 +893,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -787,8 +916,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
const auto v_store_tile = tile_elementwise_in(v_element_func, v_buf);
store_tile(v_lds_window_tmp, v_store_tile); // store the prefetch
}
if constexpr(k1_loops > 1)
@@ -799,18 +928,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
2 * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
update_v_offsets(number<2 * kK1>{});
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
@@ -938,18 +1056,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
(2 + i_k1.value) * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
}
block_sync_lds();
@@ -961,7 +1068,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

View File

@@ -4,15 +4,246 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
struct BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>
{
using Base = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
if constexpr(Problem::kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kDwordx4Bytes = 16;
return kDwordx4Bytes / sizeof(VDataType);
}
else
{
return Base::template GetAlignmentV<Problem>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
if constexpr(Problem::kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// For VECTORIZED_LAYOUT, kKPack should match GEMM's kKPerThread
// to ensure correct LDS access pattern
constexpr auto gemm_k_decomp = GetGemmKDecomposition<Problem>();
constexpr index_t kKPerThread = gemm_k_decomp.template at<1>();
return kKPerThread;
}
else
{
return Base::template GetSmemKPackV<Problem>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
if constexpr(Problem::kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// For VECTORIZED_LAYOUT, we need to use our GetSmemKPackV for V size calculation
constexpr index_t SingleKSize = [&]() {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = Base::template GetSmemKPackK<Problem>();
constexpr index_t KVector = Base::template GetAlignmentK<Problem>();
constexpr index_t kPad = KPack;
static_assert(WarpSize * KVector >= kKPerBlock &&
WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
constexpr index_t LaneGroups = WarpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
}();
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>(); // Use our override!
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return max(SingleKSize, SingleVSize);
}
else
{
return Base::template GetSingleSmemElementSpaceSize<Problem>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
if constexpr(Problem::kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<Base::NumKVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(make_tuple(number<Base::NumKVLdsBuffers>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{})),
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
else
{
return Base::template MakeVLdsBlockDescriptor<Problem>();
}
}
// Helper to get GEMM's K decomposition parameters (kABKLane, kKPerThread)
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetGemmKDecomposition()
{
// Get the KV block GEMM and extract warp gemm's K decomposition
constexpr auto gemm = Base::template GetKVBlockGemm<Problem>();
using BlockGemm = remove_cvref_t<decltype(gemm)>;
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
// Return kABKLane and kKPerThread from warp gemm
return make_tuple(number<WG::WarpGemmAttribute::Impl::kABKLane>{},
number<WG::kKPerThread>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
if constexpr(Problem::kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// For VECTORIZED_LAYOUT, use column-major distribution (K direction vector load)
// The K decomposition must match GEMM's BWarpDstrEncoding to ensure correct LDS access
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// Get GEMM's K decomposition (kABKLane, kKPerThread)
constexpr auto gemm_k_decomp = GetGemmKDecomposition<Problem>();
constexpr index_t kABKLane = gemm_k_decomp.template at<0>();
constexpr index_t kKPerThread = gemm_k_decomp.template at<1>();
// K1 = kKPerThread (inner K dimension, matches GEMM's expectation)
// K0 = kKPerBlock / K1 (outer K dimension)
// But we need K0 to match kABKLane for the per-warp iteration
constexpr index_t K1 = kKPerThread;
constexpr index_t K0 = kABKLane;
// Verify K decomposition matches GEMM's BWarpDstrEncoding requirements
static_assert(K0 == kABKLane, "K0 must match GEMM's kABKLane for correct LDS access");
static_assert(K1 == kKPerThread,
"K1 must match GEMM's kKPerThread for correct LDS access");
// K0 * K1 may be less than kKPerBlock, so we need outer iteration
constexpr index_t KPerIter = K0 * K1;
constexpr index_t KOuterIter = kKPerBlock / KPerIter;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0, "N0 is zero");
if constexpr(KOuterIter == 1)
{
// Simple case: K decomposition matches exactly
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<2, 1>,
sequence<1, 0>>{});
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock);
return dstr;
}
else
{
// Need outer K iteration
constexpr index_t K2 = KOuterIter;
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K2, K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<2, 1, 2>,
sequence<2, 0, 0>>{});
static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock);
return dstr;
}
}
else
{
// For non-VECTORIZED_LAYOUT, use base class implementation
return Base::template MakeVDramTileDistribution<Problem>();
}
}
};
} // namespace ck_tile

View File

@@ -121,6 +121,9 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
static_assert(kIsGroupMode_, "Batch prefill requires group mode");
static_assert(BlockFmhaShape_::IsVLayoutRowMajor,
"Batch prefill kernel requires RowMajor VLayout");
};
template <typename QDataType_,

View File

@@ -96,9 +96,9 @@ struct AQPickerCommon : public BlockGemmQuantBase
if constexpr(Traits::TransposeC) // transposed C
{
index_t reg_offset =
Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
Traits::APreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
if constexpr(Traits::PreshuffleQuant)
if constexpr(Traits::APreshuffleQuant)
{
auto pull_from_lane =
(__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale;
@@ -121,7 +121,7 @@ struct AQPickerCommon : public BlockGemmQuantBase
}
else
{
if constexpr(Traits::PreshuffleQuant)
if constexpr(Traits::APreshuffleQuant)
{
// A view is created on top of the preshuffled AQ, where each row of
// the view is composed of a row from a warp tile within an AQ block

View File

@@ -69,7 +69,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
@@ -127,9 +128,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -162,12 +163,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
@@ -289,9 +290,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
CBlockTensor::PackedSize>{};
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else

View File

@@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -53,7 +53,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
@@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
@@ -173,7 +173,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
@@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
else
{
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else

View File

@@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
@@ -75,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
@@ -134,8 +136,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -156,7 +162,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
@@ -354,11 +361,24 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

View File

@@ -34,7 +34,7 @@ struct AQuantBlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
@@ -43,7 +43,7 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -69,20 +69,20 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, AQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
integer_divide_ceil(WarpGemm::kK, AQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(AQuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of AQuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! AQuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / QuantGroupSize::kK > 0,
static_assert(KPerBlock / AQuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
@@ -110,8 +110,8 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool TransposeC = Problem::TransposeC;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool TransposeC = Problem::TransposeC;
};
public:

View File

@@ -36,7 +36,7 @@ struct BQuantBlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
@@ -46,8 +46,8 @@ struct BQuantBlockUniversalGemmAsBsCr
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN;
static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -72,23 +72,23 @@ struct BQuantBlockUniversalGemmAsBsCr
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of BQuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! BQuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / QuantGroupSize::kK > 0,
static_assert(KPerBlock / BQuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
@@ -153,7 +153,7 @@ struct BQuantBlockUniversalGemmAsBsCr
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
@@ -317,25 +317,21 @@ struct BQuantBlockUniversalGemmAsBsCr
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
// constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >
(NWarp * WarpGemm::kN))
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
if constexpr(Traits::NPerBlock ==
GemmTraits::QuantGroupSize::kN)
return kQScale;
else
return nIter; // for prefill needs kQscale, for decode needs
// nIter
return kQScale; // prefill: one quant group per block
}
else
{
return nIter;
return nIter; // decode or multiple groups per warp
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
@@ -370,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >=
if constexpr(GemmTraits::BQuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
GemmTraits::BQuantGroupSize::kN *
Traits::KQPerBlock +
kQScale;
else
{

View File

@@ -67,15 +67,27 @@ struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
};
template <typename, typename = void>
struct is_quantpreshuffle_enabled
struct is_Aquantpreshuffle_enabled
{
static constexpr bool value = false;
};
template <typename T>
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
struct is_Aquantpreshuffle_enabled<T, std::void_t<decltype(T::APreshuffleQuant)>>
{
static constexpr bool value = T::PreshuffleQuant;
static constexpr bool value = T::APreshuffleQuant;
};
template <typename, typename = void>
struct is_Bquantpreshuffle_enabled
{
static constexpr bool value = false;
};
template <typename T>
struct is_Bquantpreshuffle_enabled<T, std::void_t<decltype(T::BPreshuffleQuant)>>
{
static constexpr bool value = T::BPreshuffleQuant;
};
template <typename, typename = void>
@@ -206,8 +218,10 @@ struct QuantGemmKernel
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant =
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool APreshuffleQuant =
detail::is_Aquantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool BPreshuffleQuant =
detail::is_Bquantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline_>::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
@@ -476,7 +490,7 @@ struct QuantGemmKernel
{
// Step 1: Create tensor view for AQ
const auto& aq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
@@ -533,7 +547,7 @@ struct QuantGemmKernel
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!PreshuffleQuant)
!APreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
@@ -571,13 +585,13 @@ struct QuantGemmKernel
// Step 2: Create tile window (no padding for AQ)
const auto& aq_block_window = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
@@ -587,11 +601,19 @@ struct QuantGemmKernel
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!APreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
"ABQuantGrouped requires RowMajor AQ layout");
}
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_tensor_view,
@@ -605,17 +627,6 @@ struct QuantGemmKernel
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_tensor_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_tensor_view,
@@ -808,14 +819,15 @@ struct QuantGemmKernel
number<1>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped)
{
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
@@ -824,48 +836,42 @@ struct QuantGemmKernel
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(
bq_ptr,
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
QuantGroupSize::kN,
ck_tile::integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
BQuantGroupSize::kN,
kargs.QK_B);
}
else
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"ABQuantGrouped requires ColumnMajor BQ layout");
}
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
return nullptr;
@@ -881,28 +887,29 @@ struct QuantGemmKernel
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(BPreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// Number of N-dimension quantization groups per block
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
constexpr auto block_n = (BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
: BQuantGroupSize::kN / TilePartitioner::NPerBlock;
// Number of N-dimension elements per warp
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
// Determine how many warps share the same scale in N-dimension
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
constexpr auto warp_per_group = (BQuantGroupSize::kN < warp_n)
? (warp_n / BQuantGroupSize::kN)
: (BQuantGroupSize::kN / warp_n);
// Number of K-dimension quantization groups per block
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / BQuantGroupSize::kK;
// The pre-shuffled layout flattens warp_n ×
// bqk_per_block scales per row, Padded up to warp_size
@@ -911,25 +918,25 @@ struct QuantGemmKernel
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
// Adapts based on fine vs coarse quantization granularity:
// - Fine-grained (QuantGroupSize::kN < warp_n):
// - Fine-grained (BQuantGroupSize::kN < warp_n):
// Multiple quant groups per warp → fewer rows needed per block.
// height = block_n / warp_per_group
//
// - Coarse-grained (QuantGroupSize::kN >= warp_n):
// - Coarse-grained (BQuantGroupSize::kN >= warp_n):
// Each row represents one quant group.
// height = block_n
constexpr auto tile_window_height =
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
(BQuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
{
block_n_idx = block_n_idx >> 1;
}
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
{
return make_tile_window(
bq_tensor_view,
@@ -946,17 +953,22 @@ struct QuantGemmKernel
}
else
{
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"ABQuantGrouped requires RowMajor AQ layout");
}
constexpr auto tensor_dim =
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
(BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
: 1;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
make_tuple(number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{},
number<tensor_dim>{}),
{0, i_n / QuantGroupSize::kN});
{0, i_n / BQuantGroupSize::kN});
}
else
{
@@ -964,21 +976,11 @@ struct QuantGemmKernel
return make_tile_window(
bq_tensor_view,
make_tuple(number<tensor_dim>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{}),
{i_n / BQuantGroupSize::kN, 0});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
else
{
return nullptr;
@@ -1223,7 +1225,7 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
index_t m = 0;
if constexpr(PreshuffleQuant)
if constexpr(APreshuffleQuant)
{
m = kargs.M;
}
@@ -1233,7 +1235,7 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
index_t n = 0;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
n = kargs.N;
}
@@ -1244,9 +1246,9 @@ struct QuantGemmKernel
{
index_t m = 0;
index_t n = 0;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
m = kargs.M;
// m = kargs.M;
n = kargs.N;
}
return GemmPipeline{}.template operator()(a_block_window,

View File

@@ -72,7 +72,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
static constexpr index_t NPerBlockBQ =
(BQuantGroupSize::kN <= BlockGemmShape::kN)
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
: 1;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
@@ -95,7 +98,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -264,7 +268,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(
PreshuffleQuant ||
BPreshuffleQuant ||
(is_bq_row_major
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
@@ -323,15 +327,18 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant
APreshuffleQuant
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
(BPreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
@@ -484,7 +491,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
currIdx = (currIdx + 1) % 2;
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -495,7 +502,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// Note: BDataType gets converted during loading from PkInt4
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(

View File

@@ -12,21 +12,21 @@ namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockAQ = KPerBlock / AQuantGroupSize::kK;
static_assert(KPerBlock % QuantGroupSize::kK == 0,
static_assert(KPerBlock % AQuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize");
// Create DRAM tile window for AQ

View File

@@ -23,19 +23,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
using OverrideADataType =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -60,7 +60,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -78,7 +78,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -99,7 +99,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(),
Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
// clang-format on
}
@@ -156,7 +156,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -216,7 +216,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
static_assert(!APreshuffleQuant, "Memory pipeline does not support APreshuffleQuant!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&

View File

@@ -32,22 +32,22 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
if constexpr(PreshuffleQuant)
if constexpr(APreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
BlockGemmShape,
@@ -57,7 +57,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
APreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
@@ -89,7 +89,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
APreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
@@ -103,7 +103,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
MPerBlock, // XPerTile
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
APreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution_transposed();
}
}

View File

@@ -20,19 +20,19 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
using OverrideADataType =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -57,7 +57,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -75,7 +75,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -96,7 +96,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName());
// clang-format on
}
@@ -152,7 +152,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -271,7 +271,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
// only row_major for AQ
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant
APreshuffleQuant
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)

View File

@@ -12,13 +12,13 @@ namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
@@ -27,16 +27,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ =
(QuantGroupSize::kN <= NPerBlock) ? NPerBlock / QuantGroupSize::kN : 1;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
(BQuantGroupSize::kN <= NPerBlock) ? NPerBlock / BQuantGroupSize::kN : 1;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
// "NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");
// static_assert(NPerBlock % BQuantGroupSize::kN == 0,
// "NPerBlock must be a multiple of BQuantGroupSize::kN");
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
"KPerBlock must be a multiple of BQuantGroupSize::kK");
// Create DRAM tile window for BQ
template <typename BQDramBlockWindowTmp>

View File

@@ -43,14 +43,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock)
? NPerBlock / Problem::QuantGroupSize::kN
: 1;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = (Problem::BQuantGroupSize::kN <= NPerBlock)
? NPerBlock / Problem::BQuantGroupSize::kN
: 1;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
@@ -61,7 +61,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I2),
Problem::TransposeC>;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
BlockGemmShape,
@@ -72,7 +72,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout,
PreshuffleQuant>;
BPreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else

View File

@@ -26,12 +26,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
@@ -45,7 +45,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
ADataType,
BDataType>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
@@ -66,11 +66,11 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ =
(QuantGroupSize::kN <= BlockGemmShape::kN)
? integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN)
(BQuantGroupSize::kN <= BlockGemmShape::kN)
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
: 1;
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -88,7 +88,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -109,7 +109,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
// clang-format on
}
@@ -165,7 +165,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -252,7 +252,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(
PreshuffleQuant ||
BPreshuffleQuant ||
(is_bq_row_major
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
@@ -304,9 +304,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant)
(BPreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)

View File

@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
index_t XPerTile,
index_t KPerBlockAQ,
index_t VecSize,
bool PreshuffleQuant>
bool APreshuffleQuant>
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -72,7 +72,7 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
if constexpr(PreshuffleQuant)
if constexpr(APreshuffleQuant)
{
// # of elements per thread
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
@@ -193,8 +193,8 @@ template <typename BlockGemmShape,
index_t NPerTile,
index_t NPerQ,
index_t KPerQ,
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool PreshuffleQuant = false>
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool BPreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
static constexpr index_t warp_size = get_warp_size();
@@ -212,10 +212,11 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
// Preshuffle only supported for ColumnMajor currently
static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
"PreshuffleQuant only supported for ColumnMajor BQLayout");
static_assert(
!(BPreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
"PreshuffleQuant only supported for ColumnMajor BQLayout");
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
// =============================================================================
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION

View File

@@ -12,13 +12,13 @@ namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
@@ -26,16 +26,16 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Probl
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t NPerBlockBQ = NPerBlock / BQuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
static_assert(NPerBlock % QuantGroupSize::kN == 0,
"NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");
static_assert(NPerBlock % BQuantGroupSize::kN == 0,
"NPerBlock must be a multiple of BQuantGroupSize::kN");
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
"KPerBlock must be a multiple of BQuantGroupSize::kK");
// Create DRAM tile window for BQ
template <typename BQDramBlockWindowTmp>

View File

@@ -22,9 +22,9 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
@@ -76,7 +76,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2
constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
@@ -109,7 +109,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of QuantGroupSize!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,

View File

@@ -24,15 +24,15 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -58,8 +58,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -93,7 +93,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
// clang-format on
}
@@ -149,7 +149,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -412,7 +412,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2);
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK;
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / BQuantGroupSize::kK;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start

View File

@@ -120,7 +120,7 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename QuantGroupSize_,
typename AQuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
@@ -133,7 +133,7 @@ using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
AQuantGroupSize_,
void,
TransposeC_,
ComputeDataType_,
@@ -147,7 +147,7 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename QuantGroupSize_,
typename BQuantGroupSize_,
typename ComputeDataType_ = ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
@@ -160,7 +160,7 @@ using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
BlockGemmShape_,
Traits_,
void,
QuantGroupSize_,
BQuantGroupSize_,
false, // no TransposeC
ComputeDataType_,
Scheduler_,

View File

@@ -25,7 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
@@ -69,14 +69,14 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using Base::m_preload;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);
static constexpr index_t GetVectorSizeBQ()
{
@@ -94,7 +94,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
// clang-format on
}
@@ -115,7 +115,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// then by vector width to get an approximate number of vector loads.
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
QuantGroupSize::kK * QuantGroupSize::kK),
BQuantGroupSize::kK * BQuantGroupSize::kK),
VectorLoadSize);
// ToDo: Hardcoded, need to change in future. How many instruction emit per iteration
@@ -360,11 +360,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BQBlockTile bq_block_tile, bq_block_tile_2;
bq_block_tile = load_tile(bq_copy_dram_window);
// move BQ to tile 1
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
@@ -437,11 +437,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
@@ -474,11 +474,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile = load_tile(bq_copy_dram_window);
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});

View File

@@ -33,7 +33,8 @@ inline std::string quant_type_to_string(QuantType quant_type)
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool PreshuffleQuant_,
bool APreshuffleQuant_,
bool BPreshuffleQuant_,
bool PreshuffleB_,
typename ALayout_,
typename BLayout_,
@@ -71,8 +72,9 @@ struct TileGemmQuantTraits
static constexpr index_t NumWaveGroups = 1;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
static constexpr bool PreshuffleB = PreshuffleB_;
static constexpr bool APreshuffleQuant = APreshuffleQuant_;
static constexpr bool BPreshuffleQuant = BPreshuffleQuant_;
static constexpr bool PreshuffleB = PreshuffleB_;
};
} // namespace ck_tile

View File

@@ -101,6 +101,55 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | |
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, F16, F16, true>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | |
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, BF16, BF16, true>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -393,6 +393,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
@@ -453,6 +456,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(

View File

@@ -184,6 +184,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pip
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -389,6 +401,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipe
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,

View File

@@ -20,6 +20,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -45,7 +45,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# ABQuant tests
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
test_gemm_quant_aquant_prefill.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
test_gemm_quant_aquant_transpose_c.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
test_gemm_quant_aquant_preshuffle.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# ABQuant tests split into 4 files
add_gtest_executable(test_tile_gemm_quant_abquant_base
test_gemm_quant_abquant_base.cpp
)
@@ -61,21 +76,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# AQuant tests
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
test_gemm_quant_aquant_prefill.cpp
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant
test_gemm_quant_abquant_preshuffleQuant.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
test_gemm_quant_aquant_transpose_c.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
test_gemm_quant_aquant_preshuffle.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
target_compile_options(test_tile_gemm_quant_abquant_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant tests (without PreshuffleB) - split into 6 files
add_gtest_executable(test_tile_gemm_quant_bquant_1d_128
@@ -188,6 +192,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
test_tile_gemm_quant_aquant_prefill
test_tile_gemm_quant_aquant_transpose_c
test_tile_gemm_quant_aquant_preshuffle
# ABQuant tests
test_tile_gemm_quant_abquant_base
test_tile_gemm_quant_abquant_padding
test_tile_gemm_quant_abquant_preshuffle
test_tile_gemm_quant_abquant_preshuffleQuant
# BQuant tests
test_tile_gemm_quant_bquant_1d_128
test_tile_gemm_quant_bquant_1d_64

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantPreshuffleQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -75,7 +75,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool APreshuffleQuant = GemmConfig::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = GemmConfig::BPreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
@@ -111,7 +112,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
PreshuffleQuant,
APreshuffleQuant,
BPreshuffleQuant,
PreshuffleB,
ALayout,
BLayout,

View File

@@ -34,7 +34,8 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool APreshuffleQuant = false;
static constexpr bool BPreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false;
@@ -110,7 +111,7 @@ struct GemmConfigMxFp4 : public GemmConfigBase
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool APreshuffleQuant = true;
};
struct GemmConfigTransposeC : public GemmConfigBase
@@ -120,8 +121,8 @@ struct GemmConfigTransposeC : public GemmConfigBase
struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool TransposeC = true;
static constexpr bool APreshuffleQuant = true;
static constexpr bool TransposeC = true;
};
struct GemmConfigPadding : public GemmConfigBase
@@ -138,7 +139,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigDecode
struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
@@ -149,7 +150,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
@@ -160,7 +161,7 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename Tuple>
@@ -244,7 +245,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
if constexpr(Base::GemmConfig::PreshuffleQuant)
if constexpr(Base::GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
@@ -481,7 +482,7 @@ class TestCkTileGemmAQuantMem
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
if constexpr(Base::GemmConfig::PreshuffleQuant)
if constexpr(Base::GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
@@ -727,7 +728,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, QuantGroupSize::kN);
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
}
else if constexpr(GemmConfig::PreshuffleQuant)
else if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<QDataType> bq_shuffle_host =
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK);
@@ -1024,7 +1025,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
if constexpr(Base::GemmConfig::PreshuffleQuant)
if constexpr(Base::GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / AQuantGroupSize::kK);
@@ -1041,7 +1042,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, BQuantGroupSize::kN);
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
}
else if constexpr(GemmConfig::PreshuffleQuant)
else if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<QDataType> bq_shuffle_host =
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / BQuantGroupSize::kK);

View File

@@ -117,6 +117,7 @@ class TestCkTileGroupedGemmABQuant : public ::testing::Test
Config::kPadN,
Config::kPadK,
false,
false,
Config::PreshuffleB,
ALayout,
BLayout,
@@ -241,6 +242,7 @@ class TestCkTileGroupedGemmABQuant : public ::testing::Test
Config::kPadN,
Config::kPadK,
false,
false,
Config::PreshuffleB,
ALayout,
BLayout,

View File

@@ -112,6 +112,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
false,
PreshuffleB,
ALayout,
BLayout,
@@ -289,6 +290,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
false,
PreshuffleB,
ALayout,
BLayout,