Softmax client example (#396)

* Update Softmax device operation interface.

* Update ckProfiler.

* Update Softmax unit tests (UT).

* Update example.

* Add client example.

* Apply clang-format.

Co-authored-by: Adam Osewski <aosewski@amd.com>
This commit is contained in:
Adam Osewski
2022-09-06 19:22:48 +02:00
committed by GitHub
parent 7589116121
commit 3da5c19e62
13 changed files with 738 additions and 331 deletions

View File

@@ -9,37 +9,41 @@
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
// Bring the ck and ck::tensor_operation::device namespaces into file scope so the
// device types used below can be named without qualification.
using namespace ck;
using namespace ck::tensor_operation::device;
// Element types used to instantiate the example: ck::half_t for input and output,
// float for AccDataType (presumably the accumulation type -- TODO confirm).
using InDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
// Short alias for the element-wise PassThrough functor type.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compile-time values passed as the Rank and NumReduceDim template arguments of
// DeviceInstance.
constexpr int Rank = 3;
constexpr int NumReduceDim = 1;
// NOTE(review): DeviceInstance is defined here over DeviceSoftmax and again
// immediately afterwards over DeviceSoftmaxImpl. This text is a diff rendering, so
// this alias (no element-wise operation arguments) looks like the removed side of
// the hunk -- confirm before treating it as live code.
using DeviceInstance = DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector
8>; // OutScalarPerVector
// Concrete softmax device instance used by this example: DeviceSoftmaxImpl
// instantiated with the data types, Rank and NumReduceDim declared above,
// PassThrough for both element-wise operation parameters, and the numeric
// tuning arguments labelled on each line below.
using DeviceInstance = DeviceSoftmaxImpl<InDataType,
AccDataType,
OutDataType,
PassThrough, // InElementwiseOperation
PassThrough, // AccElementwiseOperation
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector
8>; // OutScalarPerVector
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
@@ -196,7 +200,7 @@ int main(int argc, char* argv[])
if(args.do_verification)
{
using ReferenceInstance =
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
ck::tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
ReferenceInstance ref;
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims);
auto invoker = ref.MakeInvoker();
@@ -220,7 +224,9 @@ int main(int argc, char* argv[])
&alpha,
&beta,
in_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer());
out_dev.GetDeviceBuffer(),
PassThrough{},
PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{