mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Merge commit 'f6c2ff9dcedbc58065ae1fc10a661f00716c6839' into develop
This commit is contained in:
@@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
static_assert(NumGroupsToMerge >= 1);
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiAB = isMultiA || isMultiB;
|
||||
|
||||
// NGCHW is not supported for multiAB
|
||||
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
@@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
NumDTensor == 0 && !isMultiAB && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
@@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
|
||||
@@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr bool isMultiD = DsDataType::Size() > 0;
|
||||
static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;
|
||||
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
!isMultiABD && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \
|
||||
AComputeDataType, BComputeDataType
|
||||
AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
|
||||
@@ -780,8 +784,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
sizeof(EDataType);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
typename GridwiseGemm::Argument gemm_arg{p_a_grid,
|
||||
p_b_grid,
|
||||
p_e_grid,
|
||||
GemmM,
|
||||
GemmN,
|
||||
GemmK,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1,
|
||||
false,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
|
||||
@@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t MaxGemmsNum = 32;
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
NumDTensor == 0 && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
AComputeDataType
|
||||
AComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
|
||||
|
||||
|
||||
@@ -730,6 +730,15 @@ struct UnaryAbs
|
||||
{
|
||||
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = ck::type_convert<bhalf_t>(ck::math::abs(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct UnarySqrt
|
||||
@@ -744,6 +753,79 @@ struct UnarySqrt
|
||||
};
|
||||
};
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
const float& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<double, double>(double& y, const double& x) const
|
||||
{
|
||||
const double& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, half_t>(half_t& y, const half_t& x) const
|
||||
{
|
||||
const float a = type_convert<half_t>(x);
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<half_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
const float& a = x;
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<half_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
const float& a = x;
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<bhalf_t, bhalf_t>(bhalf_t& y,
|
||||
const bhalf_t& x) const
|
||||
{
|
||||
const float a = type_convert<float>(x);
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<int, int>(int& y, const int& x) const
|
||||
{
|
||||
const int8_t& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
const int8_t& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
const float floor_;
|
||||
const float ceil_;
|
||||
};
|
||||
|
||||
struct Relu
|
||||
{
|
||||
template <typename T>
|
||||
@@ -756,6 +838,9 @@ struct Relu
|
||||
y = x > 0 ? x : 0;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
|
||||
{
|
||||
@@ -763,6 +848,13 @@ struct Relu
|
||||
float y_f32 = x_f32 > 0 ? x_f32 : 0;
|
||||
y = type_convert<bhalf_t>(y_f32);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float y_f32 = x > 0 ? x : 0;
|
||||
y = type_convert<bhalf_t>(y_f32);
|
||||
};
|
||||
};
|
||||
|
||||
// Fast GeLU
|
||||
@@ -915,6 +1007,16 @@ struct Sigmoid
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = one / (one + math::exp(-x));
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Silu
|
||||
@@ -942,6 +1044,15 @@ struct TanH
|
||||
|
||||
y = math::tanh(x);
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(math::tanh(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct ACos
|
||||
@@ -1201,6 +1312,13 @@ struct Swish
|
||||
y = type_convert<Y>(x / (1.f + math::exp(bx)));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float bx = -beta_ * x;
|
||||
y = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
|
||||
};
|
||||
|
||||
const float beta_;
|
||||
};
|
||||
|
||||
@@ -1219,6 +1337,16 @@ struct SoftRelu
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
|
||||
};
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1240,6 +1368,17 @@ struct Power
|
||||
T shifted_scaled_x = casted_alpha + casted_beta * x;
|
||||
y = math::pow(shifted_scaled_x, casted_gamma);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
const float shifted_scaled_x = alpha_ + beta_ * x;
|
||||
y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
const float gamma_;
|
||||
@@ -1260,6 +1399,16 @@ struct ClippedRelu
|
||||
T casted_beta = type_convert<T>(beta_);
|
||||
y = math::min(casted_beta, math::max(casted_alpha, x));
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
};
|
||||
@@ -1278,6 +1427,16 @@ struct LeakyRelu
|
||||
T casted_alpha = type_convert<T>(alpha_);
|
||||
y = x >= 0 ? x : x * casted_alpha;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1295,6 +1454,16 @@ struct Elu
|
||||
T casted_alpha = type_convert<T>(alpha_);
|
||||
y = x > 0 ? x : casted_alpha * math::expm1(x);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1313,6 +1482,16 @@ struct Logistic
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
|
||||
};
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
|
||||
@@ -71,11 +71,13 @@ template <typename ADataType,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename BComputeDataType_ = AComputeDataType_>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename BComputeDataType_ = AComputeDataType_,
|
||||
bool DoElementwiseBeforeCShuffle = false>
|
||||
struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0);
|
||||
|
||||
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
@@ -796,37 +798,60 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
tensor_operation::element_wise::PassThrough pass_through{};
|
||||
const auto& vpgr_to_lds_element_op = [&] {
|
||||
if constexpr(DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return cde_element_op;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
const auto& lds_to_global_element_op = [&] {
|
||||
if constexpr(!DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return cde_element_op;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CDEElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
vpgr_to_lds_element_op()};
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
@@ -860,7 +885,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation,
|
||||
conditional_t<!DoElementwiseBeforeCShuffle,
|
||||
CDEElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
|
||||
// support arbitray type
|
||||
Sequence<1,
|
||||
@@ -881,7 +908,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
|
||||
cde_element_op};
|
||||
lds_to_global_element_op()};
|
||||
|
||||
// space filling curve for threadwise C in VGPR before shuffle
|
||||
constexpr auto sfc_c_vgpr =
|
||||
|
||||
@@ -186,6 +186,8 @@ __global__ void
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled).
|
||||
/// @tparam DoElementwiseBeforeCShuffle Whether the cde_elementwise should be performed before or
|
||||
/// after elementwise op.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -233,7 +235,8 @@ template <typename ALayout,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
bool PermuteB = false,
|
||||
bool DoElementwiseBeforeCShuffle = false>
|
||||
struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -636,7 +639,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t KBatch_)
|
||||
index_t KBatch_,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
@@ -651,7 +657,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
AK0{CalculateAK0Padded(K_, KBatch_)},
|
||||
BK0{CalculateBK0Padded(K_, KBatch_)},
|
||||
MBlock{CalculateMBlock(M_)},
|
||||
NBlock{CalculateNBlock(N_)}
|
||||
NBlock{CalculateNBlock(N_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -689,6 +698,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
index_t BK0;
|
||||
index_t MBlock;
|
||||
index_t NBlock;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Argument
|
||||
@@ -704,8 +716,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t k_batch_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
|
||||
bool is_reduce_ = false,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
CElementwiseOperation c_element_op = CElementwiseOperation{})
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
@@ -1377,10 +1401,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
@@ -1440,7 +1460,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
problem.a_element_op_,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -1471,7 +1491,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
problem.b_element_op_,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -1598,42 +1618,67 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
tensor_operation::element_wise::PassThrough pass_through{};
|
||||
const auto& vpgr_to_lds_element_op = [&] {
|
||||
if constexpr(DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return problem.c_element_op_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
const auto& lds_to_global_element_op = [&] {
|
||||
if constexpr(!DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return problem.c_element_op_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
vpgr_to_lds_element_op()};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
conditional_t<!DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
@@ -1654,7 +1699,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
lds_to_global_element_op()};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
@@ -1773,10 +1818,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
@@ -1836,7 +1877,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
problem.a_element_op_,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -1867,7 +1908,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
problem.b_element_op_,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -2059,7 +2100,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
problem.c_element_op_};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
|
||||
@@ -121,6 +121,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
@@ -34,6 +34,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -34,6 +34,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -26,6 +26,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -34,6 +34,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -26,6 +26,7 @@ using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#include "grouped_convolution_forward_clamp_xdl.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DLayouts,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataTypes,
|
||||
typename AComputeType,
|
||||
typename BComputeType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DLayouts,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DDataTypes,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::Clamp,
|
||||
AComputeType,
|
||||
BComputeType>>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DLayouts,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DDataTypes,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::Clamp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
// layout NHWGC/GKYXC/NHWGK
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,242 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,16 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_grouped_conv2d_fwd_clamp_instance
|
||||
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
|
||||
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
|
||||
|
||||
xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp
|
||||
)
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,16 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
|
||||
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
|
||||
|
||||
xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
@@ -0,0 +1,127 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -25,6 +25,28 @@
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
// NOTE: Usage of NHWGK layout for GK bias is a workaround. This test is to
|
||||
// just keep such implementation valid.
|
||||
// TODO: Add possiblity to pass GK layout and GK lengths for bias and reuse
|
||||
// the same instances.
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
auto get_bias_desc(ck::index_t G, ck::index_t K)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0});
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0});
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
@@ -34,7 +56,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename OutDataType,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType,
|
||||
typename IndexType = ck::index_t>
|
||||
typename IndexType = ck::index_t,
|
||||
bool BiasGK = false>
|
||||
bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -61,12 +84,16 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
const index_t G = conv_param.G_;
|
||||
const index_t K = conv_param.K_;
|
||||
|
||||
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> d_g_n_k_wos_strides{};
|
||||
std::array<IndexType, NDimSpatial> conv_filter_strides{};
|
||||
std::array<IndexType, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<IndexType, NDimSpatial> input_left_pads{};
|
||||
@@ -80,6 +107,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
@@ -89,7 +117,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> bias(out_g_n_k_wos_desc);
|
||||
const auto bias_desc = BiasGK ? get_bias_desc<NDimSpatial>(G, K) : out_g_n_k_wos_desc;
|
||||
Tensor<OutDataType> bias(bias_desc);
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
@@ -113,7 +142,11 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
|
||||
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize());
|
||||
|
||||
const std::size_t bias_dev_buf_size =
|
||||
BiasGK ? sizeof(OutDataType) * G * K
|
||||
: sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize();
|
||||
DeviceMem bias_device_buf(bias_dev_buf_size);
|
||||
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weight.mData.data());
|
||||
@@ -244,6 +277,16 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
|
||||
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
if constexpr(BiasGK)
|
||||
{
|
||||
constexpr ck::index_t spatial_offset = 3;
|
||||
d_g_n_k_wos_strides[1] = 0;
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
d_g_n_k_wos_strides[i + spatial_offset] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
|
||||
@@ -255,7 +298,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
{e_g_n_k_wos_lengths},
|
||||
{e_g_n_k_wos_strides},
|
||||
{d_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -34,20 +35,20 @@ template <ck::index_t NDimSpatial,
|
||||
typename OutDataType,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType,
|
||||
typename IndexType = ck::index_t>
|
||||
typename IndexType = ck::index_t,
|
||||
typename OutElementOp = ck::tensor_operation::element_wise::PassThrough>
|
||||
bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
const OutElementOp out_element_op = OutElementOp{})
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
|
||||
@@ -208,6 +208,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_layout",
|
||||
"-I",
|
||||
"--in_layout",
|
||||
"--I",
|
||||
default="NCHW",
|
||||
type=str,
|
||||
required=False,
|
||||
@@ -216,6 +218,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-forw",
|
||||
"-F",
|
||||
"--forw",
|
||||
"--F",
|
||||
default=0,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -231,6 +235,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-spatial_dim",
|
||||
"-_",
|
||||
"--spatial_dim",
|
||||
"--_",
|
||||
default=2,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -239,6 +245,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-batchsize",
|
||||
"-n",
|
||||
"--batchsize",
|
||||
"--n",
|
||||
default=100,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -247,6 +255,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_channels",
|
||||
"-c",
|
||||
"--in_channels",
|
||||
"--c",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -255,6 +265,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_d",
|
||||
"-!",
|
||||
"--in_d",
|
||||
"--!",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -263,6 +275,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_h",
|
||||
"-H",
|
||||
"--in_h",
|
||||
"--H",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -271,6 +285,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_w",
|
||||
"-W",
|
||||
"--in_w",
|
||||
"--W",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -279,6 +295,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-out_channels",
|
||||
"-k",
|
||||
"--out_channels",
|
||||
"--k",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -287,6 +305,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-fil_d",
|
||||
"-@",
|
||||
"--fil_d",
|
||||
"--@",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -295,6 +315,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-fil_h",
|
||||
"-y",
|
||||
"--fil_h",
|
||||
"--y",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -303,6 +325,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-fil_w",
|
||||
"-x",
|
||||
"--fil_w",
|
||||
"--x",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -311,6 +335,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-conv_stride_d",
|
||||
"-#",
|
||||
"--conv_stride_d",
|
||||
"--#",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -319,6 +345,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-conv_stride_h",
|
||||
"-u",
|
||||
"--conv_stride_h",
|
||||
"--u",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -327,6 +355,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-conv_stride_w",
|
||||
"-v",
|
||||
"--conv_stride_w",
|
||||
"--v",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -335,6 +365,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-pad_d",
|
||||
"-$",
|
||||
"--pad_d",
|
||||
"--$",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -343,6 +375,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-pad_h",
|
||||
"-p",
|
||||
"--pad_h",
|
||||
"--p",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -351,6 +385,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-pad_w",
|
||||
"-q",
|
||||
"--pad_w",
|
||||
"--q",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -359,6 +395,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-verify",
|
||||
"-V",
|
||||
"--verify",
|
||||
"--V",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -367,6 +405,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-time",
|
||||
"-t",
|
||||
"--time",
|
||||
"--t",
|
||||
default=0,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -375,6 +415,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-dilation_d",
|
||||
"-^",
|
||||
"--dilation_d",
|
||||
"--^",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -383,6 +425,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-dilation_h",
|
||||
"-l",
|
||||
"--dilation_h",
|
||||
"--l",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -391,6 +435,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-dilation_w",
|
||||
"-j",
|
||||
"--dilation_w",
|
||||
"--j",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -399,6 +445,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-group_count",
|
||||
"-g",
|
||||
"--group_count",
|
||||
"--g",
|
||||
type=int,
|
||||
default=1,
|
||||
required=False,
|
||||
|
||||
@@ -252,7 +252,7 @@ add_subdirectory(reduce)
|
||||
add_subdirectory(convnd_fwd)
|
||||
add_subdirectory(convnd_bwd_data)
|
||||
add_subdirectory(grouped_convnd_fwd)
|
||||
add_subdirectory(grouped_convnd_fwd_bias_clamp)
|
||||
add_subdirectory(grouped_convnd_fwd_activation)
|
||||
add_subdirectory(grouped_convnd_bwd_weight)
|
||||
add_subdirectory(block_to_ctile_map)
|
||||
add_subdirectory(softmax)
|
||||
|
||||
10
test/grouped_convnd_fwd_activation/CMakeLists.txt
Normal file
10
test/grouped_convnd_fwd_activation/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_fwd_gk_bias_clamp test_grouped_convnd_fwd_gk_bias_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_gk_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance)
|
||||
endif()
|
||||
@@ -41,7 +41,8 @@ class TestGroupedConvndFwd : public ::testing::Test
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType>(
|
||||
IndexType,
|
||||
false /*BiasGK*/>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_fwd_impl.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
using InLayout = std::tuple_element_t<1, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<2, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<3, Tuple>;
|
||||
using IndexType = ck::index_t;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
Clamp out_element_op{0.f, 256.f};
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType,
|
||||
Clamp>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
out_element_op);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd2d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
using InLayout = std::tuple_element_t<1, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<2, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<3, Tuple>;
|
||||
using IndexType = ck::index_t;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType,
|
||||
true /*BiasGK*/>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd2d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
endif()
|
||||
Reference in New Issue
Block a user