mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Conv:TF32: add more instances - 2 (#2879)
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances * add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances * add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances * tf32:conv:add instances for base class DeviceConvFwd * tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD * tf32:conv:add instances for base class DeviceGroupedConvBwdWeight * add tf32 in profiler * remove gnhwc/ngchw/ngcdhw instances * remove non-ndhwgc/nhwgc/nhwc instances * add check in IsSupportedArgument()
This commit is contained in:
@@ -1499,6 +1499,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<AComputeType, ck::tf32_t> || is_same_v<BComputeType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeType, BComputeType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!IsSplitKSupported)
|
||||
{
|
||||
|
||||
@@ -951,6 +951,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -1687,6 +1687,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
|
||||
@@ -950,6 +950,22 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -1289,6 +1289,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
|
||||
@@ -1399,6 +1399,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check ConvolutionForwardSpecialization
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
|
||||
@@ -820,6 +820,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// check ConvolutionForwardSpecialization
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
|
||||
Reference in New Issue
Block a user