mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Grouped convolution forward with clamp (#2334)
* Grouped convolution forward with clamp * Optimize clamp * unary fixes * test gk bias * Revert "test gk bias" This reverts commit 8e42e29d7b. * Revert "Revert "test gk bias"" This reverts commit e73c0550ce. * workaround comment
This commit is contained in:
@@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
static_assert(NumGroupsToMerge >= 1);
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiAB = isMultiA || isMultiB;
|
||||
|
||||
// NGCHW is not supported for multiAB
|
||||
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
@@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
NumDTensor == 0 && !isMultiAB && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
@@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
|
||||
@@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr bool isMultiD = DsDataType::Size() > 0;
|
||||
static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;
|
||||
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
!isMultiABD && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \
|
||||
AComputeDataType, BComputeDataType
|
||||
AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
|
||||
@@ -780,8 +784,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
sizeof(EDataType);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
typename GridwiseGemm::Argument gemm_arg{p_a_grid,
|
||||
p_b_grid,
|
||||
p_e_grid,
|
||||
GemmM,
|
||||
GemmN,
|
||||
GemmK,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1,
|
||||
false,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
|
||||
@@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t MaxGemmsNum = 32;
|
||||
static constexpr bool DoElementwiseBeforeCShuffle =
|
||||
NumDTensor == 0 && is_same_v<EDataType, bhalf_t> &&
|
||||
!is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
AComputeDataType
|
||||
AComputeDataType, DoElementwiseBeforeCShuffle
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
|
||||
|
||||
|
||||
@@ -730,6 +730,15 @@ struct UnaryAbs
|
||||
{
|
||||
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = ck::type_convert<bhalf_t>(ck::math::abs(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct UnarySqrt
|
||||
@@ -744,6 +753,79 @@ struct UnarySqrt
|
||||
};
|
||||
};
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
const float& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<double, double>(double& y, const double& x) const
|
||||
{
|
||||
const double& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, half_t>(half_t& y, const half_t& x) const
|
||||
{
|
||||
const float a = type_convert<half_t>(x);
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<half_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
const float& a = x;
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<half_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
const float& a = x;
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<bhalf_t, bhalf_t>(bhalf_t& y,
|
||||
const bhalf_t& x) const
|
||||
{
|
||||
const float a = type_convert<float>(x);
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<int, int>(int& y, const int& x) const
|
||||
{
|
||||
const int8_t& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
const int8_t& a = x;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
const float floor_;
|
||||
const float ceil_;
|
||||
};
|
||||
|
||||
struct Relu
|
||||
{
|
||||
template <typename T>
|
||||
@@ -756,6 +838,9 @@ struct Relu
|
||||
y = x > 0 ? x : 0;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
|
||||
{
|
||||
@@ -763,6 +848,13 @@ struct Relu
|
||||
float y_f32 = x_f32 > 0 ? x_f32 : 0;
|
||||
y = type_convert<bhalf_t>(y_f32);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float y_f32 = x > 0 ? x : 0;
|
||||
y = type_convert<bhalf_t>(y_f32);
|
||||
};
|
||||
};
|
||||
|
||||
// Fast GeLU
|
||||
@@ -915,6 +1007,16 @@ struct Sigmoid
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = one / (one + math::exp(-x));
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Silu
|
||||
@@ -942,6 +1044,15 @@ struct TanH
|
||||
|
||||
y = math::tanh(x);
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(math::tanh(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct ACos
|
||||
@@ -1201,6 +1312,13 @@ struct Swish
|
||||
y = type_convert<Y>(x / (1.f + math::exp(bx)));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float bx = -beta_ * x;
|
||||
y = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
|
||||
};
|
||||
|
||||
const float beta_;
|
||||
};
|
||||
|
||||
@@ -1219,6 +1337,16 @@ struct SoftRelu
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
|
||||
};
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1240,6 +1368,17 @@ struct Power
|
||||
T shifted_scaled_x = casted_alpha + casted_beta * x;
|
||||
y = math::pow(shifted_scaled_x, casted_gamma);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
const float shifted_scaled_x = alpha_ + beta_ * x;
|
||||
y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
const float gamma_;
|
||||
@@ -1260,6 +1399,16 @@ struct ClippedRelu
|
||||
T casted_beta = type_convert<T>(beta_);
|
||||
y = math::min(casted_beta, math::max(casted_alpha, x));
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
};
|
||||
@@ -1278,6 +1427,16 @@ struct LeakyRelu
|
||||
T casted_alpha = type_convert<T>(alpha_);
|
||||
y = x >= 0 ? x : x * casted_alpha;
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1295,6 +1454,16 @@ struct Elu
|
||||
T casted_alpha = type_convert<T>(alpha_);
|
||||
y = x > 0 ? x : casted_alpha * math::expm1(x);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1313,6 +1482,16 @@ struct Logistic
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
|
||||
}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
|
||||
};
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
|
||||
@@ -71,11 +71,13 @@ template <typename ADataType,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename BComputeDataType_ = AComputeDataType_>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename BComputeDataType_ = AComputeDataType_,
|
||||
bool DoElementwiseBeforeCShuffle = false>
|
||||
struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0);
|
||||
|
||||
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
@@ -796,37 +798,60 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
tensor_operation::element_wise::PassThrough pass_through{};
|
||||
const auto& vpgr_to_lds_element_op = [&] {
|
||||
if constexpr(DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return cde_element_op;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
const auto& lds_to_global_element_op = [&] {
|
||||
if constexpr(!DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return cde_element_op;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CDEElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
vpgr_to_lds_element_op()};
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
@@ -860,7 +885,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation,
|
||||
conditional_t<!DoElementwiseBeforeCShuffle,
|
||||
CDEElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
|
||||
// support arbitrary type
|
||||
Sequence<1,
|
||||
@@ -881,7 +908,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
|
||||
cde_element_op};
|
||||
lds_to_global_element_op()};
|
||||
|
||||
// space filling curve for threadwise C in VGPR before shuffle
|
||||
constexpr auto sfc_c_vgpr =
|
||||
|
||||
@@ -186,6 +186,8 @@ __global__ void
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled).
|
||||
/// @tparam DoElementwiseBeforeCShuffle Whether the cde_elementwise should be performed before or
|
||||
/// after the CShuffle stage.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -233,7 +235,8 @@ template <typename ALayout,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
bool PermuteB = false,
|
||||
bool DoElementwiseBeforeCShuffle = false>
|
||||
struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -636,7 +639,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t KBatch_)
|
||||
index_t KBatch_,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
@@ -651,7 +657,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
AK0{CalculateAK0Padded(K_, KBatch_)},
|
||||
BK0{CalculateBK0Padded(K_, KBatch_)},
|
||||
MBlock{CalculateMBlock(M_)},
|
||||
NBlock{CalculateNBlock(N_)}
|
||||
NBlock{CalculateNBlock(N_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -689,6 +698,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
index_t BK0;
|
||||
index_t MBlock;
|
||||
index_t NBlock;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Argument
|
||||
@@ -704,8 +716,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t k_batch_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
|
||||
bool is_reduce_ = false,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
CElementwiseOperation c_element_op = CElementwiseOperation{})
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
@@ -1377,10 +1401,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
@@ -1440,7 +1460,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
problem.a_element_op_,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -1471,7 +1491,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
problem.b_element_op_,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -1598,42 +1618,67 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
tensor_operation::element_wise::PassThrough pass_through{};
|
||||
const auto& vpgr_to_lds_element_op = [&] {
|
||||
if constexpr(DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return problem.c_element_op_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
const auto& lds_to_global_element_op = [&] {
|
||||
if constexpr(!DoElementwiseBeforeCShuffle)
|
||||
{
|
||||
return problem.c_element_op_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return pass_through;
|
||||
}
|
||||
};
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
vpgr_to_lds_element_op()};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
conditional_t<!DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
@@ -1654,7 +1699,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
lds_to_global_element_op()};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
@@ -1773,10 +1818,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
@@ -1836,7 +1877,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
problem.a_element_op_,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -1867,7 +1908,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
problem.b_element_op_,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
@@ -2059,7 +2100,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
problem.c_element_op_};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
|
||||
Reference in New Issue
Block a user