Optimize clamp

This commit is contained in:
Bartlomiej Kocot
2025-06-13 09:55:51 +00:00
parent b145e825f9
commit 2d4c5129ce
5 changed files with 140 additions and 73 deletions

View File

@@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_assert(NumGroupsToMerge >= 1);
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr bool isMultiAB = isMultiA || isMultiB;
// NGCHW is not supported for multiAB
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
@@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr bool DoElementwiseBeforeCShuffle =
NumDTensor == 0 && !isMultiAB && is_same_v<EDataType, bhalf_t> &&
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType
BComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = std::conditional_t<
isMultiA || isMultiB,

View File

@@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr bool isMultiD = DsDataType::Size() > 0;
static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;
static constexpr bool DoElementwiseBeforeCShuffle =
!isMultiABD && is_same_v<EDataType, bhalf_t> &&
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \
AComputeDataType, BComputeDataType
AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;

View File

@@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t MaxGemmsNum = 32;
static constexpr bool DoElementwiseBeforeCShuffle =
NumDTensor == 0 && is_same_v<EDataType, bhalf_t> &&
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
AComputeDataType
AComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;

View File

@@ -71,11 +71,13 @@ template <typename ADataType,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType_ = AComputeDataType_>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType_ = AComputeDataType_,
bool DoElementwiseBeforeCShuffle = false>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0);
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
@@ -796,37 +798,60 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
tensor_operation::element_wise::PassThrough pass_through{};
const auto& vpgr_to_lds_element_op = [&] {
if constexpr(DoElementwiseBeforeCShuffle)
{
return cde_element_op;
}
else
{
return pass_through;
}
};
const auto& lds_to_global_element_op = [&] {
if constexpr(!DoElementwiseBeforeCShuffle)
{
return cde_element_op;
}
else
{
return pass_through;
}
};
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CDEElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
vpgr_to_lds_element_op()};
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
@@ -860,7 +885,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
conditional_t<!DoElementwiseBeforeCShuffle,
CDEElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1,
@@ -881,7 +908,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
lds_to_global_element_op()};
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =

View File

@@ -186,6 +186,8 @@ __global__ void
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
/// @tparam DoElementwiseBeforeCShuffle Whether the cde_elementwise should be performed before or
/// after elementwise op.
template <typename ALayout,
typename BLayout,
typename CLayout,
@@ -233,7 +235,8 @@ template <typename ALayout,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
bool PermuteB = false,
bool DoElementwiseBeforeCShuffle = false>
struct GridwiseGemm_xdl_cshuffle_v3
{
static constexpr auto I0 = Number<0>{};
@@ -1615,42 +1618,67 @@ struct GridwiseGemm_xdl_cshuffle_v3
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
tensor_operation::element_wise::PassThrough pass_through{};
const auto& vpgr_to_lds_element_op = [&] {
if constexpr(DoElementwiseBeforeCShuffle)
{
return problem.c_element_op_;
}
else
{
return pass_through;
}
};
const auto& lds_to_global_element_op = [&] {
if constexpr(!DoElementwiseBeforeCShuffle)
{
return problem.c_element_op_;
}
else
{
return pass_through;
}
};
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
vpgr_to_lds_element_op()};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
ThisThreadBlock, // ThreadGroup
conditional_t<!DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
@@ -1671,7 +1699,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
problem.c_element_op_};
lds_to_global_element_op()};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =