mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example
This commit is contained in:
@@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>;
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
#else
|
||||
// clang-format off
|
||||
|
||||
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#else
|
||||
|
||||
@@ -145,7 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
|
||||
using Base::MWaves;
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, BDataType>{};
|
||||
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType>{};
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
|
||||
@@ -50,8 +50,7 @@ template <typename ThreadGroup,
|
||||
typename SrcDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t ScalarPerVector,
|
||||
bool SrcXor = true>
|
||||
index_t ScalarPerVector>
|
||||
struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
@@ -68,20 +67,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
static constexpr auto wave_thread_cluster_lengths =
|
||||
Sequence<ThreadClusterLengths{}.At(I0),
|
||||
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
|
||||
1>{};
|
||||
static constexpr auto wave_cluster_lengths =
|
||||
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// It makes the whole wavefront load contiguous memory, which is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
|
||||
static __device__ constexpr bool AreThreadClusterLengthsValid()
|
||||
@@ -180,6 +171,25 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
constexpr auto wave_cluster_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
|
||||
{
|
||||
return Number<ThreadGroup::GetNumOfThread() / 64>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
|
||||
constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() / 64));
|
||||
|
||||
@@ -327,8 +337,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
static constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
|
||||
@@ -98,10 +98,12 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK<ALayo
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferScalarPerVector,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferScalarPerVector,
|
||||
BBlockLdsAddExtraN,
|
||||
|
||||
@@ -174,17 +174,15 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
// Returns the descriptor of the A tile, built with make_naive_tensor_descriptor from a
// (AK0PerBlock, MPerBlock, AK1) tuple followed by a second tuple.
// NOTE(review): the second tuple is presumably the per-dimension strides — confirm against
// make_naive_tensor_descriptor's signature.
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
// A matrix in LDS memory, destination of blockwise copy.
|
||||
// Variant 1: second tuple is ((MPerBlock + ABlockLdsExtraM) * AK1, AK1, 1).
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
|
||||
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
|
||||
// Variant 2: second tuple is (AK1, KPerBlock, 1).
// NOTE(review): as written, this statement follows an unconditional return and can never
// execute. The text looks like a diff rendering (old lines followed by new lines) — confirm
// which of the two descriptors is the intended one before relying on this function.
return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
|
||||
make_tuple(AK1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
|
||||
// Returns the descriptor of the B tile, built with make_naive_tensor_descriptor from a
// (BK0PerBlock, NPerBlock, BK1) tuple followed by a second tuple.
// NOTE(review): the second tuple is presumably the per-dimension strides — confirm against
// make_naive_tensor_descriptor's signature.
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
|
||||
{
|
||||
// B matrix in LDS memory, destination of blockwise copy.
|
||||
// Variant 1: second tuple is ((NPerBlock + BBlockLdsExtraN) * BK1, BK1, 1).
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
|
||||
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
|
||||
// Variant 2: second tuple is (BK1, KPerBlock, 1).
// NOTE(review): as written, this statement follows an unconditional return and can never
// execute. The text looks like a diff rendering (old lines followed by new lines) — confirm
// which of the two descriptors is the intended one before relying on this function.
return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
|
||||
make_tuple(BK1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -566,10 +564,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<AK0PerBlock, MPerBlock, AK1>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ADataType,
|
||||
AComputeDataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferScalarPerVector>(
|
||||
@@ -582,10 +582,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<BK0PerBlock, NPerBlock, BK1>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BDataType,
|
||||
BComputeDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferScalarPerVector>(
|
||||
|
||||
@@ -76,10 +76,12 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
bool BBlockLdsExtraN,
|
||||
@@ -102,9 +104,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
static constexpr auto M01 = 1;
|
||||
static constexpr auto N01 = 1;
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
static constexpr auto KPerBlock = Number<K1Value * K0PerBlock>{};
|
||||
static constexpr auto M01 = 1;
|
||||
static constexpr auto N01 = 1;
|
||||
|
||||
static constexpr auto gemm_padder =
|
||||
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
@@ -613,8 +616,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -630,9 +634,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
max_lds_align);
|
||||
make_tuple(
|
||||
Number<KPerBlock>{} * Number<MPerBlock>{}, K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
@@ -645,8 +650,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -662,9 +668,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
max_lds_align);
|
||||
make_tuple(
|
||||
Number<KPerBlock>{} * Number<NPerBlock>{}, K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -672,10 +679,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<1, K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
FloatA,
|
||||
ComputeType,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
3,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
@@ -688,10 +697,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<1, K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
FloatB,
|
||||
ComputeType,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
|
||||
@@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveNum, index_t nDim>
|
||||
struct lambda_wave_cluster_dimension
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
if((nDim - i) == 3)
|
||||
return WaveNum;
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1159,15 +1159,6 @@ struct MfmaSelector
|
||||
#endif
|
||||
}
|
||||
|
||||
// Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
|
||||
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
|
||||
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
|
||||
// Explicit specialization for the argument list <f8_t, 32, 32, pk_i4_t, true, false>;
// it unconditionally returns MfmaInstr::mfma_f32_32x32x16f8f8.
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, pk_i4_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user