mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Conv:TF32: add more instances - 2 (#2879)
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* tf32:conv:add instances for base class DeviceConvFwd
* tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD
* tf32:conv:add instances for base class DeviceGroupedConvBwdWeight
* add tf32 in profiler
* remove gnhwc/ngchw/ngcdhw instances
* remove non-ndhwgc/nhwgc/nhwc instances
* add check in IsSupportedArgument()
[ROCm/composable_kernel commit: fada1a3cae]
This commit is contained in:
@@ -31,13 +31,15 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
|
||||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
|
||||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
|
||||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
|
||||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
|
||||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
double compute_error = 0;
|
||||
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
@@ -52,8 +54,9 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
|
||||
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
|
||||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
|
||||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
|
||||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
|
||||
is_same_v<OutDataType, F32> || is_same_v<ComputeDataType, TF32> ||
|
||||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
@@ -69,8 +72,9 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
|
||||
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
|
||||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
|
||||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
|
||||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
|
||||
is_same_v<AccDataType, F32> || is_same_v<ComputeDataType, TF32> ||
|
||||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
@@ -93,13 +97,15 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
|
||||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
|
||||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
|
||||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
|
||||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
|
||||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
double compute_error = 0;
|
||||
@@ -115,8 +121,9 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
|
||||
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
|
||||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
|
||||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
|
||||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
|
||||
is_same_v<OutDataType, F32> || is_same_v<ComputeDataType, TF32> ||
|
||||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
@@ -132,8 +139,9 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
|
||||
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
|
||||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
|
||||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
|
||||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
|
||||
is_same_v<AccDataType, F32> || is_same_v<ComputeDataType, TF32> ||
|
||||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
@@ -149,11 +157,67 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, float> &&
|
||||
std::is_same_v<ComputeDataType, ck::tf32_t>,
|
||||
bool>::type
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-5,
|
||||
double atol = 3e-5)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<double>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = *std::next(std::begin(out), i);
|
||||
const double r = *std::next(std::begin(ref), i);
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
err_count++;
|
||||
}
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_floating_point_v<ranges::range_value_t<Range>> &&
|
||||
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
||||
!std::is_same_v<ranges::range_value_t<Range>, half_t> &&
|
||||
!std::is_same_v<ComputeDataType, ck::tf32_t>,
|
||||
bool>::type
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
@@ -200,7 +264,9 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
|
||||
@@ -251,7 +317,9 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
||||
@@ -301,7 +369,9 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_integral_v<ranges::range_value_t<Range>> &&
|
||||
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
|
||||
@@ -358,7 +428,9 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, f8_t>),
|
||||
bool>
|
||||
@@ -407,7 +479,9 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
|
||||
bool>
|
||||
@@ -452,7 +526,9 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
template <typename Range,
|
||||
typename RefRange,
|
||||
typename ComputeDataType = ranges::range_value_t<Range>>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, f4_t>),
|
||||
bool>
|
||||
|
||||
@@ -1499,6 +1499,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<AComputeType, ck::tf32_t> || is_same_v<BComputeType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeType, BComputeType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!IsSplitKSupported)
|
||||
{
|
||||
|
||||
@@ -951,6 +951,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -1687,6 +1687,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
|
||||
@@ -950,6 +950,22 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -1289,6 +1289,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<ComputeTypeA, ComputeTypeB>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
|
||||
@@ -1399,6 +1399,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check ConvolutionForwardSpecialization
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
|
||||
@@ -820,6 +820,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// check ConvolutionForwardSpecialization
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
|
||||
@@ -280,8 +280,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
using FloatBAdjusted =
|
||||
conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
|
||||
#else
|
||||
using FloatAAdjusted = ComputeTypeA;
|
||||
using FloatBAdjusted = ComputeTypeB;
|
||||
using FloatAAdjusted = conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
|
||||
using FloatBAdjusted = conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
|
||||
#endif
|
||||
|
||||
// M0/M1/M1Padding
|
||||
@@ -760,19 +760,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// register
|
||||
// sanity check
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAAdjusted, half_t>::value || is_same<FloatAAdjusted, bhalf_t>::value) &&
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
K1 <= 4) ||
|
||||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8) ||
|
||||
((is_same<FloatAAdjusted, f8_t>::value || is_same<FloatAAdjusted, bf8_t>::value) &&
|
||||
(is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
K1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(K1,
|
||||
MfmaSelector<FloatAAdjusted,
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
FloatBAdjusted,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
@@ -787,7 +787,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
NPerXdl,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>{};
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -45,6 +45,24 @@ struct NumericUtils<float>
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<ck::tf32_t>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr uint32_t nan_mask = 0x7F800000;
|
||||
static constexpr uint32_t head_mask = 0xFF800000;
|
||||
static constexpr uint32_t mant_mask = 0x7FFFFF;
|
||||
static constexpr uint32_t exp_mask = 0xFF;
|
||||
static constexpr uint32_t Inf = 0x7F800000;
|
||||
static constexpr uint32_t NegInf = 0xFF800000;
|
||||
static constexpr uint32_t NaN = 0x7F800001;
|
||||
static constexpr uint32_t Neg0 = 0x80000000;
|
||||
static constexpr bool has_inf = true;
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<half_t>
|
||||
{
|
||||
|
||||
@@ -28,6 +28,7 @@ template <ck::index_t NDimSpatial,
|
||||
ck::index_t NumAElementwiseTensor = 0,
|
||||
ck::index_t NumBElementwiseTensor = 0,
|
||||
ck::index_t NumDElementwiseTensor = 0,
|
||||
typename ComputeDataType = OutDataType,
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvBwdData : public device::BaseOperator
|
||||
{
|
||||
@@ -142,8 +143,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
c,
|
||||
x);
|
||||
|
||||
v_acc += ck::type_convert<float>(v_out) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_out)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_wei));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -235,8 +238,11 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
y,
|
||||
x);
|
||||
|
||||
v_acc += ck::type_convert<float>(v_out) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
v_acc +=
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_out)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_wei));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -354,8 +360,12 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
x);
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<float>(v_out) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<
|
||||
ComputeDataType>(v_out)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<
|
||||
ComputeDataType>(v_wei));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ namespace instance {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
@@ -84,17 +85,17 @@ using device_grouped_conv_bwd_data_transpose_xdl_bf16_instances =
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>,
|
||||
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>,
|
||||
|
||||
@@ -18,6 +18,7 @@ namespace instance {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -143,6 +144,43 @@ using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// f32_f32_f32_f32 tf32
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4> , 8, make_default_loop_scheduler(), TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -18,6 +18,7 @@ namespace instance {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
@@ -375,6 +376,41 @@ using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f32_tf32_optimized_loads_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// A K1 one access for each thread per load
|
||||
// 32x32
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2, make_default_loop_scheduler(), TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
// 16x16
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2, make_default_loop_scheduler(), TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, make_default_loop_scheduler(), TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -411,6 +447,83 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// f32_f32_f32_f32 tf32
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f32_tf32_generic_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f32_tf32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// f16_f16_f16_comp_f8
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
|
||||
@@ -18,6 +18,7 @@ namespace instance {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -143,6 +144,43 @@ using device_grouped_conv_bwd_data_xdl_scale_f32_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// f32_f32_f32_f32 tf32
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_scale_f32_tf32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
@@ -58,6 +59,24 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances = std::tuple
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
@@ -74,6 +75,39 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_bilinear_instances = std:
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_bilinear_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F32, F32, F32, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
@@ -96,6 +97,62 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
index_t MaxTransposeTransferSrcScalarPerVector = 1,
|
||||
index_t MaxTransposeTransferDstScalarPerVector = 1>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| MaxTranspose| MaxTranspose|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| TransferSrc| TransferDst|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
@@ -74,6 +75,39 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_scale_instances = std::tu
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_scale_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -117,14 +117,28 @@ struct DeviceOperationInstanceFactory<
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -275,12 +289,26 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
if constexpr(is_same_v<ComputeTypeA, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
@@ -127,6 +127,38 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
@@ -140,6 +172,22 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_load
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
@@ -479,6 +527,38 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_insta
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
@@ -492,6 +572,22 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_l
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
|
||||
@@ -349,20 +349,37 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
|
||||
is_same_v<ComputeTypeB, float>)
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: ComputeTypeA and ComputeTypeB should be the same");
|
||||
if constexpr(is_same_v<ComputeTypeA, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -595,20 +612,36 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
|
||||
is_same_v<ComputeTypeB, float>)
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: ComputeTypeA and ComputeTypeB should be the same");
|
||||
if constexpr(is_same_v<ComputeTypeA, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
@@ -62,6 +62,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
@@ -138,11 +153,20 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
|
||||
is_same_v<ComputeTypeB, float>)
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
@@ -62,6 +62,22 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
@@ -138,11 +154,20 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
|
||||
is_same_v<ComputeTypeB, float>)
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
@@ -570,6 +570,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
@@ -606,6 +620,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipe
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -618,6 +646,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipe
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -630,6 +672,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -641,6 +697,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -1177,7 +1247,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipe
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -1188,6 +1258,76 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipe
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
|
||||
@@ -570,6 +570,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwg
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
|
||||
@@ -7,12 +7,15 @@ add_instance_library(
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<
|
||||
2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_optimized_loads_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_optimized_loads_instances<
|
||||
2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,6 +10,7 @@ set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp
|
||||
@@ -21,9 +22,13 @@ set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -6,12 +6,15 @@ set(GROUPED_CONV3D_BWD_DATA
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_optimized_loads_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<
|
||||
3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_optimized_loads_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_tf32_optimized_loads_instances<
|
||||
3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -6,6 +6,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp
|
||||
@@ -17,9 +18,13 @@ set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,6 +2,7 @@
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_bilinear_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_bilinear_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,6 +2,7 @@
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT_SCALE
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_scale_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_scale_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -29,7 +29,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename OutDataType,
|
||||
typename WeiDataType,
|
||||
typename InDataType>
|
||||
typename InDataType,
|
||||
typename ComputeDataType = InDataType>
|
||||
bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -96,7 +97,11 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>();
|
||||
OutElementOp,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
ComputeDataType>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
@@ -171,9 +176,13 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
{
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
|
||||
using ComputeType = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
|
||||
OutDataType,
|
||||
WeiDataType>;
|
||||
using ComputeType_ = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
|
||||
OutDataType,
|
||||
WeiDataType>;
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeType_) < sizeof(ComputeDataType),
|
||||
ComputeType_,
|
||||
ComputeDataType>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = conv_param.K_;
|
||||
@@ -222,18 +231,21 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
};
|
||||
|
||||
// do GEMM
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
InDataType,
|
||||
OutElementOp,
|
||||
WeiElementOp,
|
||||
InElementOp>;
|
||||
using DeviceOp =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
InDataType,
|
||||
OutElementOp,
|
||||
WeiElementOp,
|
||||
InElementOp,
|
||||
ComputeDataType,
|
||||
ComputeDataType>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
|
||||
@@ -21,9 +21,10 @@ enum struct ConvLayout
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
F32_F32_F32_TF32, // 3
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_bwd_data"
|
||||
@@ -37,6 +38,7 @@ static void print_helper_msg()
|
||||
<< "arg2: data type (0: Output fp32, Weight fp32, Input fp32\n"
|
||||
<< " 1: Output fp16, Weight fp16, Input fp16\n"
|
||||
<< " 2: Output bf16, Weight bf16, Input bf16\n"
|
||||
<< " 3: Output fp32, Weight fp32, Input fp32, Compute tf32)\n"
|
||||
<< "arg3: tensor layout (0: Output[G, N, Ho, Wo, C], Weight[G, K, Y, X, C], Input[G, N, Hi, Wi, K]\n"
|
||||
<< " 1: Output[N, Ho, Wo, G, C], Weight[G, K, Y, X, C], Input[N, Hi, Wi, G, K])\n"
|
||||
<< " 2: Output[N, G, C, Ho, Wo], Weight[G, K, Y, X, C], Input[N, G, K, Hi, Wi])\n"
|
||||
@@ -82,6 +84,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
@@ -94,16 +99,18 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
auto in_layout,
|
||||
auto wei_type,
|
||||
auto out_type,
|
||||
auto in_type) {
|
||||
auto in_type,
|
||||
auto compute_type) {
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using OutLayout = decltype(out_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using InLayout = decltype(in_layout);
|
||||
|
||||
using OutDataType = decltype(out_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using InDataType = decltype(in_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using InDataType = decltype(in_type);
|
||||
using ComputeDataType = decltype(compute_type);
|
||||
|
||||
bool pass = ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
@@ -111,7 +118,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
InDataType>(
|
||||
InDataType,
|
||||
ComputeDataType>(
|
||||
do_verification, init_method, do_log, time_kernel, params, split_k);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
@@ -123,60 +131,84 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -186,60 +218,84 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ enum struct ConvDataType
|
||||
F16_F16_F16_BF8_F8, // 3
|
||||
I8_I8_I8, // 4
|
||||
BF16_BF16_BF16, // 5
|
||||
F32_F32_F32_TF32, // 6
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_bwd_weight"
|
||||
@@ -41,7 +42,8 @@ static void print_helper_msg()
|
||||
<< " 2: Input bf16, Weight fp32, Output bf16\n"
|
||||
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
|
||||
<< " 4: Input int8, Weight int8, Output int8\n"
|
||||
<< " 5: Input bf16, Weight bf16, Output bf16)\n"
|
||||
<< " 5: Input bf16, Weight bf16, Output bf16\n"
|
||||
<< " 6: Input fp32, Weight fp32, Output fp32, Compute tf32)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
|
||||
"N, K, Ho, Wo]\n"
|
||||
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
|
||||
@@ -97,6 +99,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
@@ -155,6 +160,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
@@ -171,6 +182,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -191,6 +208,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
@@ -218,6 +241,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
@@ -239,6 +268,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
return profile(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -269,6 +304,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
@@ -297,6 +338,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -226,6 +226,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
@@ -245,6 +251,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
@@ -292,6 +304,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -311,6 +329,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
@@ -326,6 +350,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
@@ -341,6 +371,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
|
||||
@@ -20,14 +20,15 @@ enum struct ConvLayout
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
BF8_F8_F8, // 7
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
BF8_F8_F8, // 7
|
||||
F32_F32_F32_TF32, // 8
|
||||
};
|
||||
|
||||
enum struct IndexType
|
||||
@@ -51,7 +52,8 @@ static void print_helper_msg()
|
||||
<< " 4: Input fp8, Weight fp8, Output fp8\n"
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8\n"
|
||||
<< " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
|
||||
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
@@ -103,6 +105,9 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
|
||||
@@ -165,6 +170,12 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -181,6 +192,12 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -20,14 +20,15 @@ enum struct ConvLayout
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
BF8_F8_F8, // 7
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
BF8_F8_F8, // 7
|
||||
F32_F32_F32_TF32, // 8
|
||||
};
|
||||
|
||||
enum struct IndexType
|
||||
@@ -51,7 +52,8 @@ static void print_helper_msg()
|
||||
<< " 4: Input fp8, Weight fp8, Output fp8\n"
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8\n"
|
||||
<< " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
|
||||
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
@@ -103,6 +105,9 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
|
||||
@@ -168,6 +173,12 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -184,6 +195,12 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user