mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Softmax client example (#396)
* Update Softmax device operation interface. * Update ckProfiler. * Update Softmax UT. * Update example. * Client example. * Clang format Co-authored-by: Adam Osewski <aosewski@amd.com>
This commit is contained in:
@@ -9,37 +9,41 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
|
||||
using namespace ck;
|
||||
using namespace ck::tensor_operation::device;
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
constexpr int Rank = 3;
|
||||
constexpr int NumReduceDim = 1;
|
||||
|
||||
using DeviceInstance = DeviceSoftmax<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
256, // BlockSize
|
||||
8, // ClusterM
|
||||
32, // ClusterK
|
||||
1, // SliceM
|
||||
8, // SliceK
|
||||
1, // SrcVecDim (0=M, 1=K)
|
||||
8, // SrcScalarPerVector
|
||||
8>; // OutScalarPerVector
|
||||
using DeviceInstance = DeviceSoftmaxImpl<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
PassThrough, // InElementwiseOperation
|
||||
PassThrough, // AccElementwiseOperation
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
256, // BlockSize
|
||||
8, // ClusterM
|
||||
32, // ClusterK
|
||||
1, // SliceM
|
||||
8, // SliceK
|
||||
1, // SrcVecDim (0=M, 1=K)
|
||||
8, // SrcScalarPerVector
|
||||
8>; // OutScalarPerVector
|
||||
|
||||
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
@@ -196,7 +200,7 @@ int main(int argc, char* argv[])
|
||||
if(args.do_verification)
|
||||
{
|
||||
using ReferenceInstance =
|
||||
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
|
||||
ck::tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
|
||||
ReferenceInstance ref;
|
||||
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims);
|
||||
auto invoker = ref.MakeInvoker();
|
||||
@@ -220,7 +224,9 @@ int main(int argc, char* argv[])
|
||||
&alpha,
|
||||
&beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer());
|
||||
out_dev.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user