mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
Add testing of GetGenericInstance() in client_example of Softmax
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
@@ -60,6 +61,24 @@ int main(int argc, char* argv[])
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
const auto g_op_ptr = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetGenericInstance();
|
||||
|
||||
auto g_op_argument_ptr = g_op_ptr->MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta,
|
||||
in.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
if(!g_op_ptr->IsSupportedArgument(g_op_argument_ptr.get()))
|
||||
throw std::runtime_error(
|
||||
"Generic instance should be suitable for various input lengths/strides");
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
@@ -122,6 +141,7 @@ int main(int argc, char* argv[])
|
||||
<< best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
if(found)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
|
||||
Reference in New Issue
Block a user