diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index aef5624cad..401b161d11 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -60,6 +61,24 @@ int main(int argc, char* argv[]) PassThrough, Rank, NumReduceDim>; + + const auto g_op_ptr = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetGenericInstance(); + + auto g_op_argument_ptr = g_op_ptr->MakeArgumentPointer(in_lengths, + in_strides, + reduce_dims, + alpha, + beta, + in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + + if(!g_op_ptr->IsSupportedArgument(g_op_argument_ptr.get())) + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); @@ -122,6 +141,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()