mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Standalone softmax kernel (#284)
* initial stub for standalone softmax * start device_softmax_mk_to_mk as a wrapper to device_reduce_mk_to_m * host softmax validates * compiles; to implement beta scaling * use NaN trick to efficiently ignore OOB values during sum of exponentials * freeload device_reduce's utility functions * clean up interface * adding prior value (beta scaling) * remove restriction related to perf considerations * apply clang-format * clean; disable diagnostics * resolve conflicts * add exp wrapper * honor HostTensorDesc interface; allow implicit cast from different vector<T> type * test softmax for fp16/fp32 * update readme * amend commit NaN trick * remove redundant param added during development * format * replace ScalarDataType with AccDataType * separate out test programs by precision type * move softmax sample code to its own folder * format * keep up with recent changes in reduction API * remove extra header
This commit is contained in:
@@ -39,7 +39,9 @@ template <typename AccDataType,
|
||||
typename SrcThreadDesc_M_K,
|
||||
typename DstThreadDesc_M,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
|
||||
struct ThreadwiseReduction
|
||||
{
|
||||
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
|
||||
@@ -51,8 +53,6 @@ struct ThreadwiseReduction
|
||||
|
||||
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
|
||||
|
||||
template <typename SrcBufferType, typename DstBufferType>
|
||||
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
|
||||
{
|
||||
@@ -73,12 +73,15 @@ struct ThreadwiseReduction
|
||||
// 2) DstDesc is known at compile-time
|
||||
// 3) SrcBuffer is static buffer
|
||||
// 4) DstBuffer is static buffer
|
||||
template <typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename SrcThreadDesc_M_K,
|
||||
typename DstThreadDesc_M,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
template <
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename SrcThreadDesc_M_K,
|
||||
typename DstThreadDesc_M,
|
||||
typename OpReduce,
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
|
||||
struct ThreadwiseReductionWithIndex
|
||||
{
|
||||
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
|
||||
@@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex
|
||||
|
||||
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
|
||||
|
||||
template <typename SrcValueBufferType,
|
||||
typename SrcIndexBufferType,
|
||||
typename DstValueBufferType,
|
||||
|
||||
Reference in New Issue
Block a user