mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Multi AB support for wave transfer (#3578)
* Add multi AB support to wave transfer * Improvements to multi ABD examples * Add instances and use intrawave v1 instead of interwave * Apply changes to other transfers * Wave transfer: add support for multiple internal vgpr buffers * Fix compilation error on gfx11
This commit is contained in:
@@ -488,6 +488,19 @@ struct ABTransferThreadTiles
|
||||
{
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Lds>(p_shared_AB, size);
|
||||
}
|
||||
|
||||
template <index_t numElements, typename Type>
|
||||
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
|
||||
{
|
||||
if constexpr(numElements > 1)
|
||||
{
|
||||
return array;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array[I0];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -133,6 +133,19 @@ struct ABTransferThreadTilesPreShuffle
|
||||
{
|
||||
return make_static_buffer<AddressSpaceEnum::Vgpr, LDSType>(size);
|
||||
}
|
||||
|
||||
template <index_t numElements, typename Type>
|
||||
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
|
||||
{
|
||||
if constexpr(numElements > 1)
|
||||
{
|
||||
return array;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array[I0];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -318,43 +318,43 @@ struct ABTransferWaveTiles
|
||||
const index_t block_mn_id,
|
||||
const index_t)
|
||||
{
|
||||
// Note: GlobalBufferNum is currently not used but it will be needed
|
||||
// once we add other pipelines. It is currently needed only for
|
||||
// consistency with the thread tiles approach
|
||||
static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
|
||||
constexpr index_t NumABTensor = ABsDataType::Size();
|
||||
static_assert(NumABTensor == 1, "multiAB currently not supported");
|
||||
|
||||
using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;
|
||||
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
index_t wave_idK = wave_idx[I1];
|
||||
index_t wave_idMN = wave_idx[I0];
|
||||
|
||||
const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
|
||||
index_t lane_group_grid = grid_lane_id[I0];
|
||||
index_t lane_local_id_grid = grid_lane_id[I1];
|
||||
|
||||
const auto block_lane_id = GetBlockLaneIdx();
|
||||
index_t lane_group_block = block_lane_id[I0];
|
||||
index_t lane_local_id_block = block_lane_id[I1];
|
||||
|
||||
return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
|
||||
const auto idx_as_block_begin = generate_tuple(
|
||||
[&](auto iTensor) {
|
||||
using ABDataType = remove_cvref_t<tuple_element_t<iTensor, ABsDataType>>;
|
||||
const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
|
||||
index_t lane_group_grid = grid_lane_id[I0];
|
||||
index_t lane_local_id_grid = grid_lane_id[I1];
|
||||
return make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
|
||||
wave_idK,
|
||||
lane_group_grid,
|
||||
lane_local_id_grid);
|
||||
},
|
||||
Number<NumABTensor>{});
|
||||
|
||||
return ThreadGroupTransferGlobal<GridDescriptor,
|
||||
BlockDescriptor,
|
||||
ABDataType,
|
||||
ABDataType,
|
||||
ABsDataType,
|
||||
LDSTypeAB,
|
||||
ABElementwiseOperation,
|
||||
Sequence<MNRepeat_, KRepeat_, I1, I1>,
|
||||
Sequence<MNWaves_, KWaves_, I1, I1>,
|
||||
Sequence<I0, I1, I2, I3>,
|
||||
ABK1Value,
|
||||
ABDoTranspose>(
|
||||
grid_descriptor[I0],
|
||||
ABDoTranspose,
|
||||
GlobalBufferNum>(
|
||||
grid_descriptor,
|
||||
block_descriptor,
|
||||
make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
|
||||
wave_idK,
|
||||
lane_group_grid,
|
||||
lane_local_id_grid),
|
||||
idx_as_block_begin,
|
||||
make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block),
|
||||
ab_element_op);
|
||||
}
|
||||
@@ -398,6 +398,12 @@ struct ABTransferWaveTiles
|
||||
{
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Lds>(p_shared_AB, size);
|
||||
}
|
||||
|
||||
template <index_t numElements, typename Type>
|
||||
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
|
||||
{
|
||||
return array;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -218,45 +218,46 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles<ABLayout,
|
||||
const index_t block_mn_id,
|
||||
const index_t)
|
||||
{
|
||||
// Note: GlobalBufferNum is currently not used but it will be needed
|
||||
// once we add other pipelines. It is currently needed only for
|
||||
// consistency with the thread tiles approach
|
||||
static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
|
||||
constexpr index_t NumABTensor = ABsDataType::Size();
|
||||
static_assert(NumABTensor == 1, "multiAB currently not supported");
|
||||
|
||||
using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;
|
||||
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
index_t wave_idK = wave_idx[I1];
|
||||
index_t wave_idMN = wave_idx[I0];
|
||||
|
||||
const auto grid_lane_id = Base::template GetGridLaneIdx<ABDataType>();
|
||||
index_t lane_group_grid = grid_lane_id[I0];
|
||||
index_t lane_local_id_grid = grid_lane_id[I1];
|
||||
|
||||
const auto block_lane_id = GetBlockLaneIdx();
|
||||
index_t lane_group_block = block_lane_id[I0];
|
||||
index_t lane_local_id_block = block_lane_id[I1];
|
||||
|
||||
constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_;
|
||||
return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
|
||||
|
||||
const auto idx_as_block_begin = generate_tuple(
|
||||
[&](auto iTensor) {
|
||||
using ABDataType = remove_cvref_t<tuple_element_t<iTensor, ABsDataType>>;
|
||||
const auto grid_lane_id = Base::template GetGridLaneIdx<ABDataType>();
|
||||
index_t lane_group_grid = grid_lane_id[I0];
|
||||
index_t lane_local_id_grid = grid_lane_id[I1];
|
||||
return make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio,
|
||||
wave_idK * KRepeat_Grid,
|
||||
(wave_idMN % MNRepeatRatio) * MNRepeat_,
|
||||
lane_group_grid,
|
||||
lane_local_id_grid);
|
||||
},
|
||||
Number<NumABTensor>{});
|
||||
|
||||
return ThreadGroupTransferGlobal<GridDescriptor,
|
||||
BlockDescriptor,
|
||||
ABDataType,
|
||||
ABDataType,
|
||||
ABsDataType,
|
||||
LDSTypeAB,
|
||||
ABElementwiseOperation,
|
||||
Sequence<I1, KRepeat_, MNRepeat_, I1, I1>,
|
||||
Sequence<I1, KWaves_, I1, I1, I1>,
|
||||
Sequence<I0, I1, I2, I3, I4>,
|
||||
ABK1Value,
|
||||
ABDoTranspose>(
|
||||
grid_descriptor[I0],
|
||||
ABDoTranspose,
|
||||
GlobalBufferNum>(
|
||||
grid_descriptor,
|
||||
block_descriptor,
|
||||
make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio,
|
||||
wave_idK * KRepeat_Grid,
|
||||
(wave_idMN % MNRepeatRatio) * MNRepeat_,
|
||||
lane_group_grid,
|
||||
lane_local_id_grid),
|
||||
idx_as_block_begin,
|
||||
make_multi_index(wave_idMN / MNRepeatRatio,
|
||||
wave_idK * KRepeat_,
|
||||
(wave_idMN % MNRepeatRatio) * MNRepeat_,
|
||||
|
||||
@@ -364,7 +364,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
__host__ __device__ static constexpr bool AWaveTransferApplicable()
|
||||
{
|
||||
return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
|
||||
return !ForceThreadTileTransfer && APackedSize == 1 &&
|
||||
ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 &&
|
||||
!IsBPreShuffled;
|
||||
@@ -372,13 +372,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
/// Compile-time predicate telling whether the wave-tile transfer path can be
/// used for the B operand.
///
/// All of the following must hold:
///  - `ForceThreadTileTransfer` is false,
///  - `BPackedSize == 1`,
///  - source and destination scalar-per-vector for B are both 8,
///  - the block GEMM pipeline is `BlockGemmPipelineVersion::v1`,
///  - `BK1Value == 8`.
///
/// The number of B tensors is not part of the condition: the stale
/// `NumBTensor == 1` clause (a dangling, unparseable `return` left next to
/// the current one) has been dropped.
///
/// @return true when every condition above is satisfied.
__host__ __device__ static constexpr bool BWaveTransferApplicable()
{
    return !ForceThreadTileTransfer && BPackedSize == 1 &&
           BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 &&
           BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
}
|
||||
|
||||
// Limitations of the current implementation:
|
||||
// - no multiAB
|
||||
#ifdef __gfx12__
|
||||
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
|
||||
|
||||
@@ -1319,19 +1317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t numElements, typename Type>
|
||||
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
|
||||
{
|
||||
if constexpr(numElements > 1)
|
||||
{
|
||||
return array;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array[I0];
|
||||
}
|
||||
}
|
||||
|
||||
// Note: arguments k_batch and k_id should be set if splitk is used
|
||||
// with implicit gemm (no pointer shift but shift using tensor descriptors)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
@@ -1435,16 +1420,16 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
|
||||
ATransfer::template get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
get_first_element_workaround<NumATensor>(as_grid_buf),
|
||||
ATransfer::template get_first_element_workaround<NumATensor>(as_grid_buf),
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
|
||||
BTransfer::template get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
get_first_element_workaround<NumBTensor>(bs_grid_buf),
|
||||
BTransfer::template get_first_element_workaround<NumBTensor>(bs_grid_buf),
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
|
||||
Reference in New Issue
Block a user