mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Standalone sweep once softmax kernel w/ ckProfiler (#295)
* use 'sweep once' softmax kernel where applicable * threadwise copy's dst buffer can specify invalid element value * add int8 in/out float compute softmax support give a bit of leeway for int absolute tolerance as there's a single data point of all test cases showing off-by-1 error * format * softmax inherits DeviceNormalization * softmax profiler stub * tighten up reference softmax interface * example prints tensor dimension * add fp32 to softmax profiler * rename header * hook with ckProfiler * format * resolve merge conflict * resolve merge conflicts * update normalization profiler help string * resolve conflict * typo * remove residual * softmax profiler: address feedback * test for mixed precision input/output * fully qualify ck::math::isnan * add comment for device normalization interface * revise wording * constness for alpha/beta scaler pointer
This commit is contained in:
@@ -236,9 +236,14 @@ template <typename SrcData,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
bool InvalidElementAsNaN = false,
|
||||
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v2
|
||||
{
|
||||
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
|
||||
(!InvalidElementAsNaN),
|
||||
"Filling invalid element as NaN is only for floating point types");
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
if constexpr(InvalidElementAsNaN)
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
is_src_valid
|
||||
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
|
||||
: NumericLimits<DstData>::QuietNaN();
|
||||
}
|
||||
else
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
|
||||
Reference in New Issue
Block a user