mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
[rocm-libraries] ROCm/rocm-libraries#4640 (commit 37b8c81)
Fix the Composable Kernel CI and version incompatibility (#4640)

## Motivation

This PR contains 4 patches:

1. Fix the CI error in grouped GEMM.
2. Fix the incompatibility with older Linux versions.
3. Fix potential errors in flatmm.
4. Address the previous comments on the ABQuant eight-warp pipeline solution.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1f6768472e
commit
5cb8109535
@@ -116,13 +116,13 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
#ifdef __gfx9__
|
||||
static constexpr bool AsyncPipeline = (MWave * NWave == 8);
|
||||
#ifdef __gfx95__
|
||||
static constexpr bool EightWave = (MWave * NWave == 8);
|
||||
#else
|
||||
static constexpr bool AsyncPipeline = false;
|
||||
static constexpr bool EightWave = false;
|
||||
#endif
|
||||
static constexpr index_t BlockedXDLN_PerWarp =
|
||||
AsyncPipeline ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
|
||||
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
@@ -447,7 +447,7 @@ struct CShuffleEpilogue
|
||||
if constexpr(is_950 || is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
|
||||
is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
|
||||
{
|
||||
if constexpr(AsyncPipeline)
|
||||
if constexpr(EightWave)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
|
||||
@@ -780,29 +780,31 @@ struct FlatmmKernel
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m)
|
||||
{
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int GM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int GK = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
static_assert(GM != -1,
|
||||
"MakeScaleMWindow should only be instantiated when scale is enabled");
|
||||
|
||||
// per-tensor (GM==0) -> Mdim = 1, stride 0
|
||||
const index_t m_dim = (GM == 0) ? 1 : (kargs.M / GM);
|
||||
const index_t m_stride = (GM == 0) ? 0 : 1;
|
||||
|
||||
const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK);
|
||||
const index_t k_stride = 0; // your original code keeps K stride 0
|
||||
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0
|
||||
? 1
|
||||
: (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
make_tuple(m_dim, k_dim),
|
||||
make_tuple(m_stride, k_stride),
|
||||
number < (GM == 1) ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: Create tile window
|
||||
// Window extents: if GM==0, we still just broadcast from [0,*]
|
||||
return make_tile_window(scale_m_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < ScaleGranularityKA == 0
|
||||
? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock > {}),
|
||||
number < (GK == 0) ? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock > {}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
|
||||
@@ -811,27 +813,29 @@ struct FlatmmKernel
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
constexpr int GN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
constexpr int GK = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
static_assert(GN != -1,
|
||||
"MakeScaleNWindow should only be instantiated when scale is enabled");
|
||||
|
||||
// per-tensor (GN==0) -> Ndim = 1, stride 0
|
||||
const index_t n_dim = (GN == 0) ? 1 : (kargs.N / GN);
|
||||
const index_t n_stride = (GN == 0) ? 0 : 1;
|
||||
|
||||
const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK);
|
||||
const index_t k_stride = 0;
|
||||
|
||||
// Step 1: Create tensor view
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(
|
||||
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
make_tuple(k_dim, n_dim),
|
||||
make_tuple(k_stride, n_stride),
|
||||
number < (GN == 1) ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: Create tile window
|
||||
return make_tile_window(scale_n_view,
|
||||
make_tuple(number < ScaleGranularityKB == 0
|
||||
? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock > {},
|
||||
make_tuple(number < (GK == 0) ? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock > {},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, block_idx_n});
|
||||
}
|
||||
@@ -854,8 +858,6 @@ struct FlatmmKernel
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
|
||||
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
@@ -866,6 +868,8 @@ struct FlatmmKernel
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
{
|
||||
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
|
||||
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
|
||||
|
||||
@@ -76,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK);
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock / KWarp, BQuantGroupSize::kK);
|
||||
@@ -158,7 +159,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
|
||||
|
||||
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
|
||||
|
||||
@@ -364,7 +366,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
|
||||
aq_block_tensor);
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
auto pull_from_lane =
|
||||
|
||||
Reference in New Issue
Block a user