Standalone softmax kernel (#284)

* initial stub for standalone softmax

* start device_softmax_mk_to_mk as a wrapper to device_reduce_mk_to_m

* host softmax validates

* compiles; to implement beta scaling

* use NaN trick to efficiently ignore OOB values during sum of exponentials

* freeload device_reduce's utility functions

* clean up interface

* adding prior value (beta scaling)

* remove restriction related to perf considerations

* apply clang-format

* clean; disable diagnostics

* resolve conflicts

* add exp wrapper

* honor HostTensorDesc interface; allow implicit cast from different vector<T> type

* test softmax for fp16/fp32

* update readme

* amend commit NaN trick

* remove redundant param added during development

* format

* replace ScalarDataType with AccDataType

* separate out test programs by precision type

* move softmax sample code to its own folder

* format

* keep up with recent changes in reduction API

* remove extra header
This commit is contained in:
Anthony Chang
2022-06-22 03:59:19 +08:00
committed by GitHub
parent be60d60d7a
commit 15c89e81f0
21 changed files with 1371 additions and 41 deletions

View File

@@ -39,7 +39,9 @@ template <typename AccDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
bool PropagateNan>
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct ThreadwiseReduction
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -51,8 +53,6 @@ struct ThreadwiseReduction
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename SrcBufferType, typename DstBufferType>
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
{
@@ -73,12 +73,15 @@ struct ThreadwiseReduction
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template <typename AccDataType,
typename IndexDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
bool PropagateNan>
template <
typename AccDataType,
typename IndexDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct ThreadwiseReductionWithIndex
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
template <typename SrcValueBufferType,
typename SrcIndexBufferType,
typename DstValueBufferType,