mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Grouped convolution forward with clamp (#2334)
* Grouped convolution forward with clamp
* Optimize clamp
* unary fixes
* test gk bias
* Revert "test gk bias" (This reverts commit 8e42e29d7b.)
* Revert "Revert "test gk bias"" (This reverts commit e73c0550ce.)
* workaround comment
This commit is contained in:
@@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
static_assert(NumGroupsToMerge >= 1);
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiAB = isMultiA || isMultiB;
|
||||
|
||||
// NGCHW is not supported for multiAB
|
||||
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
@@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
NumDTensor == 0 && !isMultiAB && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
@@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
|
||||
@@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr bool isMultiD = DsDataType::Size() > 0;
|
||||
static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;
|
||||
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
!isMultiABD && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \
|
||||
AComputeDataType, BComputeDataType
|
||||
AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
|
||||
@@ -780,8 +784,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
sizeof(EDataType);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
typename GridwiseGemm::Argument gemm_arg{p_a_grid,
|
||||
p_b_grid,
|
||||
p_e_grid,
|
||||
GemmM,
|
||||
GemmN,
|
||||
GemmK,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1,
|
||||
false,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
|
||||
@@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t MaxGemmsNum = 32;
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
NumDTensor == 0 && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
AComputeDataType
|
||||
AComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user