Grouped convolution forward with clamp (#2334)

* Grouped convolution forward with clamp

* Optimize clamp

* unary fixes

* test gk bias

* Revert "test gk bias"

This reverts commit 8e42e29d7b.

* Revert "Revert "test gk bias""

This reverts commit e73c0550ce.

* workaround comment
This commit is contained in:
Bartłomiej Kocot
2025-06-16 15:36:53 +02:00
committed by GitHub
parent d996bc78be
commit f6c2ff9dce
41 changed files with 2103 additions and 106 deletions

View File

@@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_assert(NumGroupsToMerge >= 1);
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr bool isMultiAB = isMultiA || isMultiB;
// NGCHW is not supported for multiAB
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
@@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
// Compile-time switch, true only when all of the following hold:
//   - there are no D tensors (NumDTensor == 0),
//   - neither A nor B is a tuple of tensors (!isMultiAB),
//   - the output data type E is bhalf_t,
//   - the CDE elementwise operation is something other than PassThrough.
// It is forwarded as the trailing entry of the gridwise GEMM template parameter list.
// NOTE(review): judging by the name, the gridwise GEMM presumably applies the elementwise
// operation before the C-shuffle stage when this is true - confirm in the gridwise kernel.
static constexpr bool DoElementwiseBeforeCShuffle =
NumDTensor == 0 && !isMultiAB && is_same_v<EDataType, bhalf_t> &&
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType
BComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = std::conditional_t<
isMultiA || isMultiB,

View File

@@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr bool isMultiD = DsDataType::Size() > 0;
static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;
// Compile-time switch, true only when all of the following hold:
//   - A, B and Ds are all single/absent (!isMultiABD, i.e. no tuple A/B and no D tensors),
//   - the output data type E is bhalf_t,
//   - the CDE elementwise operation is something other than PassThrough.
// It is forwarded as the trailing entry of the gridwise GEMM v3 template parameter list.
// NOTE(review): judging by the name, the gridwise GEMM presumably applies the elementwise
// operation before the C-shuffle stage when this is true - confirm in the gridwise kernel.
static constexpr bool DoElementwiseBeforeCShuffle =
!isMultiABD && is_same_v<EDataType, bhalf_t> &&
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \
AComputeDataType, BComputeDataType
AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
@@ -780,8 +784,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
sizeof(EDataType);
}
typename GridwiseGemm::Argument gemm_arg{
p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
typename GridwiseGemm::Argument gemm_arg{p_a_grid,
p_b_grid,
p_e_grid,
GemmM,
GemmN,
GemmK,
I0,
I0,
I0,
I1,
false,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_};
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)

View File

@@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t MaxGemmsNum = 32;
// Compile-time switch, true only when all of the following hold:
//   - there are no D tensors (NumDTensor == 0),
//   - the output data type E is bhalf_t,
//   - the CDE elementwise operation is something other than PassThrough.
// It is forwarded as the trailing entry of the gridwise GEMM template parameter list.
// NOTE(review): judging by the name, the gridwise GEMM presumably applies the elementwise
// operation before the C-shuffle stage when this is true - confirm in the gridwise kernel.
static constexpr bool DoElementwiseBeforeCShuffle =
NumDTensor == 0 && is_same_v<EDataType, bhalf_t> &&
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
AComputeDataType
AComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;