Multi AB support for wave transfer (#3578)

* Add multi AB support to wave transfer

* Improvements to multi ABD examples

* Add instances and use intrawave v1 instead of interwave

* Apply changes to other transfers

* Wave transfer: add support for multiple internal vgpr buffers

* Fix compilation error gfx11
This commit is contained in:
Enrico Degregori
2026-01-29 19:29:40 +01:00
committed by GitHub
parent fabac7e2c3
commit f16d9100e4
21 changed files with 374 additions and 188 deletions

View File

@@ -488,6 +488,19 @@ struct ABTransferThreadTiles
{
return make_dynamic_buffer<AddressSpaceEnum::Lds>(p_shared_AB, size);
}
// Returns either the whole container or only its first entry, selected at
// compile time by numElements:
//   - numElements > 1 : `array` itself
//   - otherwise       : `array[I0]`
// The return type is plain `auto`, so the result is handed back by value and
// the deduced type differs between the two cases.
template <index_t numElements, typename Type>
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
{
    if constexpr(numElements <= 1)
    {
        // Single-entry case: unwrap the container down to its first entry.
        return array[I0];
    }
    else
    {
        // Multi-entry case: keep the container intact.
        return array;
    }
}
};
} // namespace ck

View File

@@ -133,6 +133,19 @@ struct ABTransferThreadTilesPreShuffle
{
return make_static_buffer<AddressSpaceEnum::Vgpr, LDSType>(size);
}
// Compile-time selector between a container and its first entry.
// When numElements is greater than one the container `array` is returned;
// in every other case only `array[I0]` is returned. Because the return type
// is deduced with plain `auto`, the value is returned by copy in both cases.
template <index_t numElements, typename Type>
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
{
    if constexpr(numElements <= 1)
    {
        // Nothing beyond the first entry is needed: return that entry alone.
        return array[I0];
    }
    else
    {
        // More than one entry: return the full container.
        return array;
    }
}
};
} // namespace ck

View File

@@ -318,43 +318,43 @@ struct ABTransferWaveTiles
const index_t block_mn_id,
const index_t)
{
// Note: GlobalBufferNum is currently not used but it will be needed
// once we add other pipelines. It is currently needed only for
// consistency with the thread tiles approach
static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
constexpr index_t NumABTensor = ABsDataType::Size();
static_assert(NumABTensor == 1, "multiAB currently not supported");
using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;
const auto wave_idx = GetWaveIdx();
index_t wave_idK = wave_idx[I1];
index_t wave_idMN = wave_idx[I0];
const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
index_t lane_group_grid = grid_lane_id[I0];
index_t lane_local_id_grid = grid_lane_id[I1];
const auto block_lane_id = GetBlockLaneIdx();
index_t lane_group_block = block_lane_id[I0];
index_t lane_local_id_block = block_lane_id[I1];
return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
const auto idx_as_block_begin = generate_tuple(
[&](auto iTensor) {
using ABDataType = remove_cvref_t<tuple_element_t<iTensor, ABsDataType>>;
const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
index_t lane_group_grid = grid_lane_id[I0];
index_t lane_local_id_grid = grid_lane_id[I1];
return make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
wave_idK,
lane_group_grid,
lane_local_id_grid);
},
Number<NumABTensor>{});
return ThreadGroupTransferGlobal<GridDescriptor,
BlockDescriptor,
ABDataType,
ABDataType,
ABsDataType,
LDSTypeAB,
ABElementwiseOperation,
Sequence<MNRepeat_, KRepeat_, I1, I1>,
Sequence<MNWaves_, KWaves_, I1, I1>,
Sequence<I0, I1, I2, I3>,
ABK1Value,
ABDoTranspose>(
grid_descriptor[I0],
ABDoTranspose,
GlobalBufferNum>(
grid_descriptor,
block_descriptor,
make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
wave_idK,
lane_group_grid,
lane_local_id_grid),
idx_as_block_begin,
make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block),
ab_element_op);
}
@@ -398,6 +398,12 @@ struct ABTransferWaveTiles
{
return make_dynamic_buffer<AddressSpaceEnum::Lds>(p_shared_AB, size);
}
// Returns `array` unchanged, whatever the value of numElements (the template
// parameter is not used in this body). The return type is plain `auto`, so the
// caller receives a copy of `array`, not a reference to it.
// NOTE(review): presumably the signature mirrors same-named helpers on other
// transfer structs so that callers can invoke it uniformly — confirm.
template <index_t numElements, typename Type>
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
{
// No first-element extraction here: the argument is passed through as a whole.
return array;
}
};
} // namespace ck

View File

@@ -218,45 +218,46 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles<ABLayout,
const index_t block_mn_id,
const index_t)
{
// Note: GlobalBufferNum is currently not used but it will be needed
// once we add other pipelines. It is currently needed only for
// consistency with the thread tiles approach
static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
constexpr index_t NumABTensor = ABsDataType::Size();
static_assert(NumABTensor == 1, "multiAB currently not supported");
using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;
const auto wave_idx = GetWaveIdx();
index_t wave_idK = wave_idx[I1];
index_t wave_idMN = wave_idx[I0];
const auto grid_lane_id = Base::template GetGridLaneIdx<ABDataType>();
index_t lane_group_grid = grid_lane_id[I0];
index_t lane_local_id_grid = grid_lane_id[I1];
const auto block_lane_id = GetBlockLaneIdx();
index_t lane_group_block = block_lane_id[I0];
index_t lane_local_id_block = block_lane_id[I1];
constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_;
return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
const auto idx_as_block_begin = generate_tuple(
[&](auto iTensor) {
using ABDataType = remove_cvref_t<tuple_element_t<iTensor, ABsDataType>>;
const auto grid_lane_id = Base::template GetGridLaneIdx<ABDataType>();
index_t lane_group_grid = grid_lane_id[I0];
index_t lane_local_id_grid = grid_lane_id[I1];
return make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio,
wave_idK * KRepeat_Grid,
(wave_idMN % MNRepeatRatio) * MNRepeat_,
lane_group_grid,
lane_local_id_grid);
},
Number<NumABTensor>{});
return ThreadGroupTransferGlobal<GridDescriptor,
BlockDescriptor,
ABDataType,
ABDataType,
ABsDataType,
LDSTypeAB,
ABElementwiseOperation,
Sequence<I1, KRepeat_, MNRepeat_, I1, I1>,
Sequence<I1, KWaves_, I1, I1, I1>,
Sequence<I0, I1, I2, I3, I4>,
ABK1Value,
ABDoTranspose>(
grid_descriptor[I0],
ABDoTranspose,
GlobalBufferNum>(
grid_descriptor,
block_descriptor,
make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio,
wave_idK * KRepeat_Grid,
(wave_idMN % MNRepeatRatio) * MNRepeat_,
lane_group_grid,
lane_local_id_grid),
idx_as_block_begin,
make_multi_index(wave_idMN / MNRepeatRatio,
wave_idK * KRepeat_,
(wave_idMN % MNRepeatRatio) * MNRepeat_,

View File

@@ -364,7 +364,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
__host__ __device__ static constexpr bool AWaveTransferApplicable()
{
return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
return !ForceThreadTileTransfer && APackedSize == 1 &&
ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 &&
!IsBPreShuffled;
@@ -372,13 +372,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
__host__ __device__ static constexpr bool BWaveTransferApplicable()
{
return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
return !ForceThreadTileTransfer && BPackedSize == 1 &&
BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
}
// Limitations of the current implementation:
// - no multiAB
#ifdef __gfx12__
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
@@ -1319,19 +1317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
// Picks, at compile time, between returning the container `array` as a whole
// (numElements > 1) and returning only its entry at index I0 (all other
// values of numElements). The result is returned by value because the return
// type is deduced with plain `auto`.
template <index_t numElements, typename Type>
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
{
    if constexpr(numElements <= 1)
    {
        // Only the first entry is forwarded.
        return array[I0];
    }
    else
    {
        // The complete container is forwarded.
        return array;
    }
}
// Note: arguments k_batch and k_id should be set if splitk is used
// with implicit gemm (no pointer shift but shift using tensor descriptors)
template <typename AGridDesc_AK0_M_K1,
@@ -1435,16 +1420,16 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
ATransfer::template get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
get_first_element_workaround<NumATensor>(as_grid_buf),
ATransfer::template get_first_element_workaround<NumATensor>(as_grid_buf),
a_block_buf,
a_block_slice_copy_step,
get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
BTransfer::template get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
get_first_element_workaround<NumBTensor>(bs_grid_buf),
BTransfer::template get_first_element_workaround<NumBTensor>(bs_grid_buf),
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,