Conv:TF32: add more instances - 2 (#2879)

* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* tf32:conv: add instances for base class DeviceConvFwd
* tf32:conv: add instances for base class DeviceGroupedConvBwdDataMultipleD
* tf32:conv: add instances for base class DeviceGroupedConvBwdWeight
* add tf32 in profiler
* remove gnhwc/ngchw/ngcdhw instances
* remove instances for layouts other than ndhwgc/nhwgc/nhwc
* add TF32 checks in IsSupportedArgument(): reject when is_tf32_supported() is false or when the A and B compute types differ
This commit is contained in:
yinglu
2025-10-10 15:28:17 +08:00
committed by GitHub
parent ad7a215aba
commit fada1a3cae
56 changed files with 2119 additions and 152 deletions

View File

@@ -1499,6 +1499,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
{
return false;
}
if constexpr(is_same_v<AComputeType, ck::tf32_t> || is_same_v<BComputeType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<AComputeType, BComputeType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
if constexpr(!IsSplitKSupported)
{

View File

@@ -951,6 +951,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
return false;
}
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())

View File

@@ -1687,6 +1687,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)

View File

@@ -950,6 +950,22 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
return false;
}
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())

View File

@@ -1289,6 +1289,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)

View File

@@ -1399,6 +1399,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
return false;
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)

View File

@@ -820,6 +820,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
{
return false;
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)

View File

@@ -280,8 +280,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using FloatBAdjusted =
conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
#else
using FloatAAdjusted = ComputeTypeA;
using FloatBAdjusted = ComputeTypeB;
using FloatAAdjusted = conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
using FloatBAdjusted = conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
#endif
// M0/M1/M1Padding
@@ -760,19 +760,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// register
// sanity check
constexpr bool is_single_rate_mfma =
(((is_same<FloatAAdjusted, half_t>::value || is_same<FloatAAdjusted, bhalf_t>::value) &&
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
K1 <= 4) ||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8) ||
((is_same<FloatAAdjusted, f8_t>::value || is_same<FloatAAdjusted, bf8_t>::value) &&
(is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
K1 < 32))
? true
: false;
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(K1,
MfmaSelector<FloatAAdjusted,
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
FloatBAdjusted,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
@@ -787,7 +787,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
NPerXdl,
MRepeat,
NRepeat,
KPack>{};
KPack,
ComputeTypeA,
ComputeTypeB>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();