Add packed cast to gridwise gemm multi d.

This commit is contained in:
Ville Pietilä
2025-08-27 11:27:03 +00:00
parent 6092643e9b
commit 481df169f2

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -1307,6 +1308,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
c_grid_desc_m_n);
}
// check if we should apply gf950 device specific optimization for BF16 output
__device__ static constexpr bool is_gfx650_and_bf16_output()
{
#if defined(__gfx950__)
return
!DoElementwiseBeforeCShuffle &&
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
std::is_same_v<AccDataType, float>;
#else
return false;
#endif
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
@@ -1578,29 +1592,53 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
};
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>,
ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>>;
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
@@ -1718,6 +1756,21 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -2111,28 +2164,51 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
};
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>,
ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>>;
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
@@ -2250,6 +2326,21 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),