mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
Standalone sweep once softmax kernel w/ ckProfiler (#295)
* use 'sweep once' softmax kernel where applicable
* threadwise copy's dst buffer can specify invalid element value
* add int8 in/out float compute softmax support
give a bit of leeway for int absolute tolerance as there's a single data point of all test cases showing off-by-1 error
* format
* softmax inherits DeviceNormalization
* softmax profiler stub
* tighten up reference softmax interface
* example prints tensor dimension
* add fp32 to softmax profiler
* rename header
* hook with ckProfiler
* format
* resolve merge conflict
* resolve merge conflicts
* update normalization profiler help string
* resolve conflict
* typo
* remove residual
* softmax profiler: address feedback
* test for mixed precision input/output
* fully qualify ck::math::isnan
* add comment for device normalization interface
* revise wording
* constness for alpha/beta scaler pointer
[ROCm/composable_kernel commit: 93c99f3d87]
This commit is contained in:
@@ -236,9 +236,14 @@ template <typename SrcData,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
bool InvalidElementAsNaN = false,
|
||||
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v2
|
||||
{
|
||||
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
|
||||
(!InvalidElementAsNaN),
|
||||
"Filling invalid element as NaN is only for floating point types");
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
if constexpr(InvalidElementAsNaN)
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
is_src_valid
|
||||
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
|
||||
: NumericLimits<DstData>::QuietNaN();
|
||||
}
|
||||
else
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
|
||||
Reference in New Issue
Block a user