mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Standalone softmax kernel (#284)
* initial stub for standalone softmax
* start device_softmax_mk_to_mk as a wrapper to device_reduce_mk_to_m
* host softmax validates
* compiles; to implement beta scaling
* use NaN trick to efficiently ignore OOB values during sum of exponentials
* freeload device_reduce's utility functions
* clean up interface
* adding prior value (beta scaling)
* remove restriction related to perf considerations
* apply clang-format
* clean; disable diagnostics
* resolve conflicts
* add exp wrapper
* honor HostTensorDesc interface; allow implicit cast from different vector<T> type
* test softmax for fp16/fp32
* update readme
* amend commit NaN trick
* remove redundant param added during development
* format
* replace ScalarDataType with AccDataType
* separate out test programs by precision type
* move softmax sample code to its own folder
* format
* keep up with recent changes in reduction API
* remove extra header
[ROCm/composable_kernel commit: 15c89e81f0]
This commit is contained in:
@@ -45,7 +45,9 @@ template <typename AccDataType,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
|
||||
struct PartitionedBlockwiseReduction
|
||||
{
|
||||
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
|
||||
@@ -62,8 +64,6 @@ struct PartitionedBlockwiseReduction
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
|
||||
|
||||
template <typename BufferType>
|
||||
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
|
||||
{
|
||||
@@ -113,13 +113,16 @@ struct PartitionedBlockwiseReduction
|
||||
// 3) in_out_value/in_out_index is the input data in vgpr from each thread
|
||||
// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
|
||||
// clang-format on
|
||||
template <typename AccDataType,
|
||||
typename IndexDataType,
|
||||
index_t BlockSize,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
template <
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
index_t BlockSize,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
|
||||
struct PartitionedBlockwiseReductionWithIndex
|
||||
{
|
||||
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
|
||||
@@ -136,9 +139,6 @@ struct PartitionedBlockwiseReductionWithIndex
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
|
||||
|
||||
// This interface accumulates on both data values and indices
|
||||
template <typename BufferType, typename IdxBufferType>
|
||||
__device__ static void Reduce(BufferType& work_val_buffer,
|
||||
|
||||
@@ -390,10 +390,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
static bool IsSupportedArgument(const Argument* pArg)
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
if(static_cast<float>(pArg->beta_) != 0.0f)
|
||||
@@ -442,11 +440,16 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
else
|
||||
{
|
||||
// cases with very small reduce_total_length should be handled by ThreadWise kernel
|
||||
if(pArg->reduce_total_length / KThreadSliceSize < 2)
|
||||
return (false);
|
||||
// if(pArg->reduce_total_length / KThreadSliceSize < 2)
|
||||
// return (false);
|
||||
};
|
||||
|
||||
return (true);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
|
||||
203
include/ck/tensor_operation/gpu/device/device_softmax.hpp
Normal file
203
include/ck/tensor_operation/gpu/device/device_softmax.hpp
Normal file
@@ -0,0 +1,203 @@
|
||||
#ifndef DEVICE_SOFTMAX_HPP
|
||||
#define DEVICE_SOFTMAX_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_multiblock.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_softmax.hpp"
|
||||
#include "gridwise_set_buffer_value.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceSoftmax : public BaseOperator
|
||||
{
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
|
||||
using Reduction = DeviceReduceMultiBlock<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
reduce::Add,
|
||||
PassThrough, // InElementwiseOperation
|
||||
PassThrough, // AccElementwiseOperation
|
||||
InMemoryDataOperationEnum::Set,
|
||||
false, // PropagateNan
|
||||
false, // OutputIndex
|
||||
false, // HaveIndexInputIfOutputIndex
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1>; // OutDstVectorSize
|
||||
|
||||
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
|
||||
|
||||
using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
GridDesc_M_K,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
struct Argument : public Reduction::Argument
|
||||
{
|
||||
Argument(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev)
|
||||
: Reduction::Argument(inLengths,
|
||||
inStrides,
|
||||
{},
|
||||
{},
|
||||
reduceDims,
|
||||
0.0f, // alpha
|
||||
0.0f, // beta
|
||||
in_dev,
|
||||
nullptr,
|
||||
out_dev,
|
||||
nullptr,
|
||||
PassThrough{},
|
||||
PassThrough{}),
|
||||
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
|
||||
// float32 precision. Make it support any data type so the fields can be removed.
|
||||
alpha_(alpha),
|
||||
beta_(beta)
|
||||
{
|
||||
// std::cout << "blkGroupSize= " << this->blkGroupSize
|
||||
// << ", numBlockTileIteration= " << this->numBlockTileIteration
|
||||
// << ", gridSize=" << this->gridSize
|
||||
// << ", invariant_total_length=" << this->invariant_total_length <<
|
||||
// std::endl;
|
||||
}
|
||||
|
||||
AccDataType alpha_;
|
||||
AccDataType beta_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
|
||||
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
|
||||
|
||||
const auto kernel_main =
|
||||
kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_main,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m_k,
|
||||
arg.blkGroupSize,
|
||||
arg.numBlockTileIteration,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.beta_,
|
||||
arg.out_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(!Reduction::IsSupportedArgument(p_arg_))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
const void* in_dev,
|
||||
void* out_dev)
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceSoftmax<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif // DEVICE_SOFTMAX_HPP
|
||||
407
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
Normal file
407
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
Normal file
@@ -0,0 +1,407 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef GRIDWISE_SOFTMAX_HPP
|
||||
#define GRIDWISE_SOFTMAX_HPP
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
#include "reduction_functions_threadwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename GridDesc_M_K>
|
||||
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
|
||||
const GridDesc_M_K out_grid_desc_m_k,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
GridwiseReduction::Run(in_grid_desc_m_k,
|
||||
out_grid_desc_m_k,
|
||||
block_group_size,
|
||||
num_k_block_tile_iteration,
|
||||
alpha,
|
||||
p_in_value_global,
|
||||
beta,
|
||||
p_out_value_global);
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename GridDesc_M_K,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct GridwiseSoftmax_mk_to_mk
|
||||
{
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
(KThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
reduce::Max,
|
||||
false>; // PropagateNan
|
||||
|
||||
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
reduce::Max,
|
||||
false>; // PropagateNan
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
__device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
|
||||
const GridDesc_M_K& out_grid_desc_m_k,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
out_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
|
||||
});
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
|
||||
});
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
AccDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(thread_buffer_desc),
|
||||
GridDesc_M_K,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m_k,
|
||||
make_multi_index(
|
||||
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
|
||||
constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);
|
||||
|
||||
///
|
||||
/// max(x)
|
||||
///
|
||||
const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
reduce::Max::template GetIdentityValue<InDataType>());
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_non_zero,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}(
|
||||
[&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); });
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
|
||||
|
||||
///
|
||||
/// sum(exp(x - max(x)))
|
||||
///
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
|
||||
});
|
||||
|
||||
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
|
||||
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
|
||||
// another value_max. As numbers become non-zero, effectively it allows invalid values to
|
||||
// slip through and contribute to the accumulated result.
|
||||
//
|
||||
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
|
||||
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
|
||||
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
|
||||
// be identified as an invalid value. We can then discard the invalid values which
|
||||
// originally failed the bound check during accumulation. This allows to ignore values that
|
||||
// failed bound check even after multiple math manipulations.
|
||||
const auto in_global_val_buf_oob_nan =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
NumericLimits<InDataType>::QuietNaN());
|
||||
|
||||
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
|
||||
AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
reduce::Add,
|
||||
false, // ignored
|
||||
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
|
||||
|
||||
using ThreadwiseSumReduce =
|
||||
ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
reduce::Add,
|
||||
false, // ignored
|
||||
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
|
||||
|
||||
reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_nan,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_thread_buf(Number<offset>{}) =
|
||||
math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
|
||||
// block_sync_lds();
|
||||
});
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
|
||||
|
||||
///
|
||||
/// softmax
|
||||
///
|
||||
reducedTiles = 0;
|
||||
if(float_equal_zero{}(beta))
|
||||
{
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_nan,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
|
||||
accu_value_buf(iM);
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_dst_store.Run(thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
out_thread_buf,
|
||||
out_grid_desc_m_k,
|
||||
out_global_val_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
|
||||
threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
}
|
||||
else
|
||||
{
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf_oob_nan,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
threadwise_dst_load.Run(out_grid_desc_m_k,
|
||||
out_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
out_thread_buf);
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
|
||||
accu_value_buf(iM) +
|
||||
beta * out_thread_buf(Number<offset>{});
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_dst_store.Run(thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
out_thread_buf,
|
||||
out_grid_desc_m_k,
|
||||
out_global_val_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
|
||||
threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
|
||||
threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif // GRIDWISE_SOFTMAX_HPP
|
||||
@@ -39,7 +39,9 @@ template <typename AccDataType,
|
||||
typename SrcThreadDesc_M_K,
|
||||
typename DstThreadDesc_M,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
|
||||
struct ThreadwiseReduction
|
||||
{
|
||||
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
|
||||
@@ -51,8 +53,6 @@ struct ThreadwiseReduction
|
||||
|
||||
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
|
||||
|
||||
template <typename SrcBufferType, typename DstBufferType>
|
||||
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
|
||||
{
|
||||
@@ -73,12 +73,15 @@ struct ThreadwiseReduction
|
||||
// 2) DstDesc is known at compile-time
|
||||
// 3) SrcBuffer is static buffer
|
||||
// 4) DstBuffer is static buffer
|
||||
template <typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename SrcThreadDesc_M_K,
|
||||
typename DstThreadDesc_M,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
template <
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename SrcThreadDesc_M_K,
|
||||
typename DstThreadDesc_M,
|
||||
typename OpReduce,
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
|
||||
struct ThreadwiseReductionWithIndex
|
||||
{
|
||||
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
|
||||
@@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex
|
||||
|
||||
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
|
||||
|
||||
template <typename SrcValueBufferType,
|
||||
typename SrcIndexBufferType,
|
||||
typename DstValueBufferType,
|
||||
|
||||
@@ -1001,6 +1001,11 @@ struct NumericLimits
|
||||
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
|
||||
|
||||
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
|
||||
|
||||
__host__ __device__ static constexpr T QuietNaN()
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -1009,12 +1014,15 @@ struct NumericLimits<half_t>
|
||||
static constexpr unsigned short binary_min = 0x0400;
|
||||
static constexpr unsigned short binary_max = 0x7BFF;
|
||||
static constexpr unsigned short binary_lowest = 0xFBFF;
|
||||
static constexpr unsigned short binary_qnan = 0x7FFF;
|
||||
|
||||
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -142,6 +142,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
|
||||
return min(x, min(ys...));
|
||||
}
|
||||
|
||||
// disallow implicit type casting
|
||||
template <typename T>
|
||||
__device__ T exp(T x);
|
||||
|
||||
template <>
|
||||
__device__ float exp<float>(float x)
|
||||
{
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ double exp<double>(double x)
|
||||
{
|
||||
return exp(x);
|
||||
}
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
|
||||
{
|
||||
|
||||
@@ -35,9 +35,27 @@
|
||||
namespace ck {
|
||||
namespace detail {
|
||||
|
||||
// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
|
||||
template <typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanIgnore
|
||||
{
|
||||
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
|
||||
{
|
||||
if(!isnan(currVal))
|
||||
{
|
||||
ReduceOperation{}(accuVal, currVal);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck;
|
||||
|
||||
// Does not check for NaN; does not guarantee NaNs be propagated to result
|
||||
// e.g., given that max(a, b) = a > b ? a : b
|
||||
// then max(NaN, 1) returns 1
|
||||
// max(1, NaN) returns NaN
|
||||
// since any comparison involving NaNs returns false
|
||||
template <typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
|
||||
{
|
||||
@@ -48,6 +66,7 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
|
||||
};
|
||||
};
|
||||
|
||||
// Check for NaN; guarantees NaNs be propagated to result
|
||||
template <typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user