Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example

This commit is contained in:
aska-0096
2025-06-03 14:54:20 +00:00
parent 407489d2c0
commit 86c8bef5d7
9 changed files with 68 additions and 42 deletions

View File

@@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>;
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#else
// clang-format off

View File

@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on
#else

View File

@@ -145,7 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
using Base::MWaves;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, BDataType>{};
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType>{};
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;

View File

@@ -50,8 +50,7 @@ template <typename ThreadGroup,
typename SrcDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector,
bool SrcXor = true>
index_t ScalarPerVector>
struct ThreadGroupTensorSliceTransfer_DirectLoad
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
@@ -68,20 +67,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto wave_thread_cluster_lengths =
Sequence<ThreadClusterLengths{}.At(I0),
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
1>{};
static constexpr auto wave_cluster_lengths =
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static __device__ constexpr bool AreThreadClusterLengthsValid()
@@ -180,6 +171,25 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
constexpr auto wave_cluster_lengths = generate_sequence_v2(
[&](auto i) {
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
{
return Number<ThreadGroup::GetNumOfThread() / 64>{};
}
else
{
return I1;
}
},
Number<nDim>{});
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
@@ -327,8 +337,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
DstCoord dst_coord_;

View File

@@ -98,10 +98,12 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK<ALayo
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferScalarPerVector,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferScalarPerVector,
BBlockLdsAddExtraN,

View File

@@ -174,17 +174,15 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(AK1, Number<KPerBlock>{}, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(BK1, Number<KPerBlock>{}, I1));
}
__host__ __device__ static constexpr auto
@@ -566,10 +564,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcAccessOrder,
ADataType,
AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
2,
ABlockTransferScalarPerVector>(
@@ -582,10 +582,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcAccessOrder,
BDataType,
BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferScalarPerVector>(

View File

@@ -76,10 +76,12 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BBlockLdsExtraN,
@@ -102,9 +104,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto KPerBlock = Number<K1Value * K0PerBlock>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
static constexpr auto gemm_padder =
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
@@ -613,8 +616,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(K1, Number<KPerBlock>{}, I1));
}
}();
@@ -630,9 +634,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
max_lds_align);
make_tuple(
Number<KPerBlock>{} * Number<MPerBlock>{}, K1, Number<KPerBlock>{}, I1));
}
}();
// B matrix in LDS memory, dst of blockwise copy
@@ -645,8 +650,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(K1, Number<KPerBlock>{}, I1));
}
}();
@@ -662,9 +668,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
max_lds_align);
make_tuple(
Number<KPerBlock>{} * Number<NPerBlock>{}, K1, Number<KPerBlock>{}, I1));
}
}();
@@ -672,10 +679,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferSrcAccessOrder,
FloatA,
ComputeType,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector>(
@@ -688,10 +697,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferSrcAccessOrder,
FloatB,
ComputeType,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector>(

View File

@@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst
}
};
template <index_t WaveNum, index_t nDim>
struct lambda_wave_cluster_dimension
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if((nDim - i) == 3)
return WaveNum;
else
return 1;
}
};
} // namespace detail
} // namespace ck

View File

@@ -1159,15 +1159,6 @@ struct MfmaSelector
#endif
}
// Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
template <>
constexpr auto GetMfma<f8_t, 32, 32, pk_i4_t, true, false>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
template <>
constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
{