mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Conv:TF32: add more instances - 2 (#2879)
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances * add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances * add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances * tf32:conv:add instances for base class DeviceConvFwd * tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD * tf32:conv:add instances for base class DeviceGroupedConvBwdWeight * add tf32 in profiler * remove gnhwc/ngchw/ngcdhw instances * remove non-ndhwgc/nhwgc/nhwc instances * add check in IsSupportedArgument()
This commit is contained in:
@@ -280,8 +280,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
using FloatBAdjusted =
|
||||
conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
|
||||
#else
|
||||
using FloatAAdjusted = ComputeTypeA;
|
||||
using FloatBAdjusted = ComputeTypeB;
|
||||
using FloatAAdjusted = conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
|
||||
using FloatBAdjusted = conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
|
||||
#endif
|
||||
|
||||
// M0/M1/M1Padding
|
||||
@@ -760,19 +760,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// register
|
||||
// sanity check
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAAdjusted, half_t>::value || is_same<FloatAAdjusted, bhalf_t>::value) &&
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
K1 <= 4) ||
|
||||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8) ||
|
||||
((is_same<FloatAAdjusted, f8_t>::value || is_same<FloatAAdjusted, bf8_t>::value) &&
|
||||
(is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
K1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(K1,
|
||||
MfmaSelector<FloatAAdjusted,
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
FloatBAdjusted,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
@@ -787,7 +787,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
NPerXdl,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>{};
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user