mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Change output gemm type to AccDataType in two stage conv bwd wei (#1283)
This commit is contained in:
@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
K0PerBlock,
|
||||
ConvBackwardWeightSpecialization>{};
|
||||
|
||||
static constexpr index_t MaxScalarPerVectorFP32 = 4;
|
||||
static constexpr index_t WorkspaceInOutScalarPerVector =
|
||||
is_same_v<AccDataType, float>
|
||||
? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32)
|
||||
: CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
// Bytes per 32 lds bank: 32 * 4 bytes
|
||||
static constexpr auto BankLength = 128;
|
||||
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
|
||||
@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
WorkspaceInOutScalarPerVector,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true,
|
||||
@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
static constexpr auto MakeElementwiseInputSequence()
|
||||
{
|
||||
return generate_sequence_v2(
|
||||
[&](auto) constexpr { return Number<CBlockTransferScalarPerVector_NWaveNPerXdl>{}; },
|
||||
[&](auto) constexpr { return Number<WorkspaceInOutScalarPerVector>{}; },
|
||||
Number<NumDTensor + 1>{});
|
||||
}
|
||||
|
||||
@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N<NDimSpatial>({}, {}));
|
||||
using CDGridDesc_M_N = decltype(concat_tuple(Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
|
||||
using DsGridPointerTuple = decltype(GetDsGridPointerTuple());
|
||||
using CDDataTypes = decltype(concat_tuple(Tuple<const EDataType*>{}, DsGridPointerTuple{}));
|
||||
using CDDataTypes = decltype(concat_tuple(Tuple<const AccDataType*>{}, DsGridPointerTuple{}));
|
||||
using EGridDesc_M_N = CGridDesc_M_N;
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
|
||||
EDataType* p_c_grid = type_convert<EDataType*>(arg.p_workspace_);
|
||||
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
|
||||
|
||||
@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
element_wise::PassThrough,
|
||||
@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
};
|
||||
|
||||
auto launch_elementwise_kernel = [&]() {
|
||||
const EDataType* p_c_grid = type_convert<const EDataType*>(arg.p_workspace_);
|
||||
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
|
||||
arg.Conv_G_;
|
||||
@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 &&
|
||||
arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user