mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Add packed cast to gridwise gemm multi d.
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -1307,6 +1308,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// check if we should apply gf950 device specific optimization for BF16 output
|
||||
__device__ static constexpr bool is_gfx650_and_bf16_output()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return
|
||||
!DoElementwiseBeforeCShuffle &&
|
||||
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
|
||||
std::is_same_v<AccDataType, float>;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
@@ -1578,29 +1592,53 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
}
|
||||
};
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>>;
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
@@ -1718,6 +1756,21 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
@@ -2111,28 +2164,51 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
};
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>>;
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
@@ -2250,6 +2326,21 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
|
||||
Reference in New Issue
Block a user