mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Standalone sweep once softmax kernel w/ ckProfiler (#295)
* use 'sweep once' softmax kernel where applicable * threadwise copy's dst buffer can specify invalid element value * add int8 in/out float compute softmax support give a bit of leeway for int absolute tolerance as there's a single data point of all test cases showing off-by-1 error * format * softmax inherits DeviceNormalization * softmax profiler stub * tighten up reference softmax interface * example prints tensor dimension * add fp32 to softmax profiler * rename header * hook with ckProfiler * format * resolve merge conflict * resolve merge conflicts * update normalization profiler help string * resolve conflict * typo * remove residual * softmax profiler: address feedback * test for mixed precision input/output * fully qualify ck::math::isnan * add comment for device normalization interface * revise wording * constness for alpha/beta scaler pointer
This commit is contained in:
@@ -49,7 +49,8 @@ template <typename InDataType,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
index_t OutDstVectorSize,
|
||||
bool SweepOnce>
|
||||
struct GridwiseSoftmax_mk_to_mk
|
||||
{
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
@@ -75,19 +76,6 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
reduce::Max,
|
||||
false>; // PropagateNan
|
||||
|
||||
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
reduce::Max,
|
||||
false>; // PropagateNan
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -105,6 +93,11 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
if constexpr(SweepOnce)
|
||||
{
|
||||
num_k_block_tile_iteration = 1;
|
||||
}
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
@@ -149,6 +142,20 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
|
||||
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
|
||||
// another value_max. As numbers become non-zero, effectively it allows invalid values to
|
||||
// slip through and contribute to the accumulated result.
|
||||
//
|
||||
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
|
||||
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
|
||||
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
|
||||
// be identified as an invalid value. We can then discard the invalid values which
|
||||
// originally failed the bound check during accumulation. This allows to ignore values that
|
||||
// failed bound check even after multiple math manipulations.
|
||||
//
|
||||
// NOTE: reset coordinate after every step because the same threadwise copy will sweep
|
||||
// through global memory 3 times back and forth
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
GridDesc_M_K,
|
||||
@@ -158,7 +165,8 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
true /* ResetCoordAfterRun */,
|
||||
true /* InvalidElementAsNaN */>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
@@ -198,21 +206,39 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
|
||||
constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);
|
||||
constexpr auto in_thread_copy_fwd_step =
|
||||
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto in_thread_copy_bwd_step =
|
||||
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
|
||||
|
||||
///
|
||||
/// max(x)
|
||||
///
|
||||
const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
reduce::Max::template GetIdentityValue<InDataType>());
|
||||
using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
|
||||
AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
reduce::Max,
|
||||
false, // param ignored
|
||||
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
|
||||
|
||||
using ThreadwiseMaxReduce =
|
||||
ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
reduce::Max,
|
||||
false, // param ignored
|
||||
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
|
||||
|
||||
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_non_zero,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
@@ -232,26 +258,6 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
///
|
||||
/// sum(exp(x - max(x)))
|
||||
///
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
|
||||
});
|
||||
|
||||
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
|
||||
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
|
||||
// another value_max. As numbers become non-zero, effectively it allows invalid values to
|
||||
// slip through and contribute to the accumulated result.
|
||||
//
|
||||
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
|
||||
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
|
||||
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
|
||||
// be identified as an invalid value. We can then discard the invalid values which
|
||||
// originally failed the bound check during accumulation. This allows to ignore values that
|
||||
// failed bound check even after multiple math manipulations.
|
||||
const auto in_global_val_buf_oob_nan =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
NumericLimits<InDataType>::QuietNaN());
|
||||
|
||||
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
|
||||
AccDataType,
|
||||
BlockSize,
|
||||
@@ -272,22 +278,25 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_nan,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
if constexpr(!SweepOnce)
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
}
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_thread_buf(Number<offset>{}) =
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
|
||||
|
||||
@@ -309,11 +318,14 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
{
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_nan,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
if constexpr(!SweepOnce)
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
}
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
|
||||
@@ -340,18 +352,27 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
}
|
||||
else
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_prior_dst_buf;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_nan,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
if constexpr(!SweepOnce)
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
}
|
||||
threadwise_dst_load.Run(out_grid_desc_m_k,
|
||||
out_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
out_thread_buf);
|
||||
in_prior_dst_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
@@ -360,7 +381,7 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
|
||||
accu_value_buf(iM) +
|
||||
beta * out_thread_buf(Number<offset>{});
|
||||
beta * in_prior_dst_buf(Number<offset>{});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user