Add testing of GetGenericInstance() in client_example of Softmax

This commit is contained in:
Qianfeng Zhang
2023-05-31 20:05:54 +00:00
parent a9f0d000eb
commit f629cd9a93

View File

@@ -6,6 +6,7 @@
#include <iomanip>
#include <iostream>
#include <vector>
#include <stdexcept>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -60,6 +61,24 @@ int main(int argc, char* argv[])
PassThrough,
Rank,
NumReduceDim>;
// Sanity-check the "generic" instance exposed by the instance factory before
// the regular instance list is queried: build an argument from the same
// problem description used by this example and require that the generic
// instance reports it as supported.
// NOTE(review): the result of GetGenericInstance() is dereferenced without a
// null check — confirm the factory always returns a valid pointer for DeviceOp.
const auto g_op_ptr = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetGenericInstance();
// Argument object for the generic instance, built from the input
// lengths/strides, the reduce dimensions, alpha/beta and the in/out device
// buffers; PassThrough{} is passed for both trailing operation parameters.
auto g_op_argument_ptr = g_op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
alpha,
beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
PassThrough{});
// Abort the example with an exception if the generic instance rejects this
// argument: it is expected to accept various input lengths/strides.
if(!g_op_ptr->IsSupportedArgument(g_op_argument_ptr.get()))
throw std::runtime_error(
"Generic instance should be suitable for various input lengths/strides");
// Get all device op instances registered with the factory for DeviceOp.
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
@@ -122,6 +141,7 @@ int main(int argc, char* argv[])
<< best_op_name << std::endl;
// run the best instance
if(found)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()