mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Softmax client example (#396)
* Update Softmax device operation interface. * Update ckProfiler. * Update Softmax UT. * Update example. * Client example. * Clang format Co-authored-by: Adam Osewski <aosewski@amd.com>
This commit is contained in:
@@ -9,7 +9,8 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
@@ -51,19 +52,23 @@ class TestSoftmax : public ::testing::Test
|
||||
using ReferenceInstance =
|
||||
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
|
||||
|
||||
using DeviceInstance = tensor_operation::device::DeviceSoftmax<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceInstance = tensor_operation::device::DeviceSoftmaxImpl<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
TestSoftmax() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}
|
||||
|
||||
@@ -97,7 +102,9 @@ class TestSoftmax : public ::testing::Test
|
||||
&alpha,
|
||||
&beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer());
|
||||
out_dev.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user