Files
composable_kernel/example/39_permute/run_permute_element_example.inc
Aviral Goel d85f065b15 chore(copyright): update copyright header for example directory (#3273)
* chore(copyright): update copyright header for codegen directory

* chore(copyright): update copyright header for example directory
2025-11-24 18:02:41 -08:00

68 lines
2.5 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Runs one element-wise permute (axis transpose) problem on the device and verifies the
// result against a host reference.
//
// Steps:
//   1. Derive the output shape from problem.shape and problem.axes.
//   2. Fill the input with uniform random values between -1 and 1.
//   3. Run DevicePermuteInstance with a PassThrough element-wise op.
//   4. Copy the result back and compare it with host_permute()'s output.
//
// problem     - input tensor shape (problem.shape) and axis permutation (problem.axes).
// time_kernel - forwarded to StreamConfig as its timing flag; the average kernel time is only
//               printed when this is true.
//
// Returns true when the device output matches the host reference (check_err with 1e-6
// tolerances). Returns false when:
//   - the device instance rejects the runtime arguments,
//   - the host reference computation fails, or
//   - the results differ.
//
// NOTE: InDataType, OutDataType, DevicePermuteInstance, PassThrough, Problem and the helpers
// transpose / to_array / host_permute are not defined here; the including code must declare them.
bool run_permute_element(const Problem& problem, bool time_kernel)
{
    const auto& input_shape = problem.shape;
    const auto& input_axes  = problem.axes;

    // Output dimensions are the input dimensions reordered by the permutation.
    const auto output_shape = transpose(input_shape, input_axes);

    Tensor<InDataType> input_tensor(input_shape);
    Tensor<OutDataType> output_tensor(output_shape);

    ck::utils::FillUniformDistribution<InDataType>{-1.f, 1.f}(input_tensor);

    DeviceMem input_device_buf(input_tensor.GetElementSpaceSizeInBytes());
    DeviceMem output_device_buf(output_tensor.GetElementSpaceSizeInBytes());

    // The unqualified data() calls below pick up a Tensor-specific overload via ADL if one
    // exists, and otherwise fall back to std::data (which calls the member .data()).
    using std::data;
    input_device_buf.ToDevice(data(input_tensor));

    // The instance is value-initialized below, so it must be default constructible.
    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

    auto permute  = DevicePermuteInstance{};
    auto argument = permute.MakeArgument(to_array(input_shape),
                                         to_array(input_tensor.GetStrides()),
                                         to_array(output_shape),
                                         to_array(output_tensor.GetStrides()),
                                         input_device_buf.GetDeviceBuffer(),
                                         output_device_buf.GetDeviceBuffer(),
                                         PassThrough{});

    if(!permute.IsSupportedArgument(argument))
    {
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
    }

    auto invoker = permute.MakeInvoker();
    // NOTE(review): nullptr presumably selects the default stream — confirm against StreamConfig.
    const float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    // Only report a time when timing was requested. With time_kernel == false the value returned
    // by Run() is presumably not a measurement (TODO confirm), so printing it would mislead.
    if(time_kernel)
    {
        std::cout << "Perf: " << ave_time << " ms" << std::endl;
    }

    output_device_buf.FromDevice(data(output_tensor));

    Tensor<OutDataType> output_tensor_host(output_shape);
    if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor_host))
    {
        // Surface the failure instead of returning false silently.
        std::cerr << "Failed to compute the host reference result, exiting!" << std::endl;
        return false;
    }

    return ck::utils::check_err(output_tensor.AsSpan<const OutDataType>(),
                                output_tensor_host.AsSpan<const OutDataType>(),
                                "Error: incorrect results in output tensor",
                                1e-6,
                                1e-6);
}
// Convenience entry point: bundles a shape/axes pair into a Problem and hands it to
// run_permute_element(), returning that function's verification result unchanged.
//
// shape       - shape of the input tensor.
// axes        - axis permutation to apply to the input.
// time_kernel - timing flag passed straight through to run_permute_element().
bool run_permute_element_example(const Problem::Shape& shape,
                                 const Problem::Axes& axes,
                                 bool time_kernel)
{
    const Problem problem{shape, axes};
    return run_permute_element(problem, time_kernel);
}