mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
* Revert "Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)"
This reverts commit 03b59f8c76.
* fix compile error on gf12x
* only run tf32 example on gfx942
* only build tf32 instance on gfx942
* ckProfiler:only support tf32 in gfx942
* delete unuseful messages
This commit is contained in:
@@ -59,6 +59,7 @@ template <ck::index_t NDimSpatial,
|
||||
ck::index_t NumAElementwiseTensor = 0,
|
||||
ck::index_t NumBElementwiseTensor = 0,
|
||||
ck::index_t NumDElementwiseTensor = 0,
|
||||
typename ComputeDataType = InDataType,
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvFwd : public device::BaseOperator
|
||||
{
|
||||
@@ -163,8 +164,18 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
k,
|
||||
c,
|
||||
x);
|
||||
v_acc +=
|
||||
ck::type_convert<float>(v_in) * ck::type_convert<float>(v_wei);
|
||||
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_in)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_wei));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc += ck::type_convert<float>(v_in) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -238,8 +249,18 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
c,
|
||||
y,
|
||||
x);
|
||||
v_acc += ck::type_convert<float>(v_in) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_in)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_wei));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc += ck::type_convert<float>(v_in) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,8 +348,18 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
z,
|
||||
y,
|
||||
x);
|
||||
v_acc += ck::type_convert<float>(v_in) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_in)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_wei));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc += ck::type_convert<float>(v_in) *
|
||||
ck::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,12 @@ template <typename ADataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
|
||||
using ElementDataTypeA =
|
||||
ck::conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
|
||||
using ElementDataTypeB =
|
||||
ck::conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
@@ -63,8 +69,8 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
|
||||
AccDataType v_acc{0};
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
ElementDataTypeA v_a{0};
|
||||
ElementDataTypeB v_b{0};
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
@@ -77,16 +83,16 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_a = type_convert<ComputeTypeA>(i4);
|
||||
v_a = type_convert<ElementDataTypeA>(i4);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
v_a = type_convert<ElementDataTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{})));
|
||||
else
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
v_a = type_convert<ElementDataTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
|
||||
@@ -94,7 +100,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
is_same_v<ADataType, f6x32_pk_t> ||
|
||||
is_same_v<ADataType, bf6x32_pk_t>)
|
||||
{
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
v_a = type_convert<ElementDataTypeA>(
|
||||
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size));
|
||||
}
|
||||
else
|
||||
@@ -111,16 +117,16 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_b = type_convert<ComputeTypeB>(i4);
|
||||
v_b = type_convert<ElementDataTypeB>(i4);
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
v_b = type_convert<ElementDataTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{})));
|
||||
else
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
v_b = type_convert<ElementDataTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
|
||||
@@ -128,7 +134,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
is_same_v<BDataType, f6x32_pk_t> ||
|
||||
is_same_v<BDataType, bf6x32_pk_t>)
|
||||
{
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
v_b = type_convert<ElementDataTypeB>(
|
||||
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size));
|
||||
}
|
||||
else
|
||||
@@ -136,8 +142,18 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
|
||||
is_same_v<ComputeTypeA, ck::tf32_t>)
|
||||
{ // only for tf32 now
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
}
|
||||
|
||||
CDataType v_c{0};
|
||||
|
||||
@@ -38,6 +38,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
using RowMajor = ck::tensor_layout::gemm::RowMajor;
|
||||
using ElementDataTypeA =
|
||||
ck::conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
|
||||
using ElementDataTypeB =
|
||||
ck::conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
|
||||
|
||||
const int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int col_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
@@ -46,8 +50,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
{
|
||||
|
||||
AccDataType v_acc{0};
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
ElementDataTypeA v_a{0};
|
||||
ElementDataTypeB v_b{0};
|
||||
CDataType v_c{0};
|
||||
|
||||
for(int k_idx = 0; k_idx < k; ++k_idx)
|
||||
@@ -76,7 +80,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
// apply b_element_op
|
||||
b_element_op(v_b, p_b_grid[element_idx_b]);
|
||||
// multiply and accumulate
|
||||
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
|
||||
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
|
||||
is_same_v<ComputeTypeA, ck::tf32_t>)
|
||||
{ // only for tf32 now
|
||||
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
|
||||
}
|
||||
}
|
||||
// apply c_element_op
|
||||
c_element_op(v_c, v_acc);
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace instance {
|
||||
// aliasing, for commonly used data type
|
||||
using F64 = double;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace instance {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
@@ -24,6 +24,7 @@ using BF8 = ck::bf8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -199,7 +200,7 @@ using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple<
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>,
|
||||
// 32x32 instance
|
||||
// 32x32 instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
@@ -284,7 +285,45 @@ using device_grouped_conv_fwd_xdl_f32_instances = std::tuple<
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 192, 16, 4, 4, 32, 32, 2, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
typename DsDataTypes = Tuple<>,
|
||||
typename OutElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_xdl_f32_tf32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE |
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl |
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 192, 16, 4, 4, 32, 32, 2, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -443,6 +443,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
|
||||
is_same_v<BComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
|
||||
@@ -215,6 +215,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
|
||||
is_same_v<BComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
@@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_insta
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
|
||||
@@ -210,6 +210,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
|
||||
is_same_v<BComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
@@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
|
||||
@@ -132,6 +132,7 @@ void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DynamicUnaryOp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -159,7 +160,8 @@ template <ck::index_t NumDimSpatial,
|
||||
typename WeiDataType,
|
||||
typename DDataTypes,
|
||||
typename OutDataType,
|
||||
typename ComputeType>
|
||||
typename AComputeType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -173,7 +175,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::DynamicUnaryOp,
|
||||
ComputeType>>
|
||||
AComputeType,
|
||||
BComputeType>>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
@@ -188,7 +191,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::DynamicUnaryOp,
|
||||
ComputeType>;
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
@@ -207,7 +211,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
@@ -244,7 +248,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
|
||||
@@ -559,6 +559,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
|
||||
@@ -94,6 +94,11 @@ function(add_instance_library INSTANCE_NAME)
|
||||
message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Only build tf32 instances for gfx942
|
||||
if(NOT INST_TARGETS MATCHES "gfx942" AND source_name MATCHES "_tf32_")
|
||||
message(DEBUG "removing tf32 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
@@ -441,7 +446,7 @@ if(CK_DEVICE_MHA_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND BUILD_MHA_LIB)
|
||||
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
|
||||
target_compile_features(device_mha_operations PUBLIC)
|
||||
set_target_properties(device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
|
||||
rocm_install(TARGETS device_mha_operations
|
||||
EXPORT device_mha_operationsTargets)
|
||||
rocm_install(EXPORT device_mha_operationsTargets
|
||||
|
||||
@@ -7,6 +7,7 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp
|
||||
@@ -34,7 +35,7 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp
|
||||
|
||||
|
||||
xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,7 +2,7 @@
|
||||
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
|
||||
include(ShardInstantiation)
|
||||
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
@@ -29,7 +29,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
|
||||
@@ -38,7 +47,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
|
||||
@@ -47,7 +56,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
|
||||
@@ -58,7 +67,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
# large tensor
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
@@ -67,7 +76,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
@@ -76,7 +85,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
@@ -87,7 +96,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
# merged groups
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
@@ -96,7 +105,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
@@ -105,7 +114,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
@@ -116,7 +125,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
#mem
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
|
||||
@@ -125,7 +134,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
|
||||
@@ -134,7 +143,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
|
||||
@@ -144,7 +153,7 @@ generate_sharded_instantiations(
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
|
||||
@@ -153,7 +162,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
|
||||
@@ -162,7 +171,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
|
||||
@@ -173,7 +182,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
#comp
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
|
||||
@@ -182,7 +191,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
|
||||
@@ -191,7 +200,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
|
||||
@@ -200,7 +209,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
|
||||
@@ -209,7 +218,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
|
||||
@@ -218,7 +227,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
|
||||
@@ -227,7 +236,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard(
|
||||
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
|
||||
)
|
||||
|
||||
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
|
||||
)
|
||||
|
||||
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.GetStrides(), ", ");
|
||||
os << "}";
|
||||
os << "} ";
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user