mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Maxpool bwd (#750)
* Add maxpool f32 kernel and example
* Revise copyright
* Add device pool bwd device op
* Support f16 and bf16
* Add compute datatype for reference code.
Prevent error in bf16
* Fix type error
* Remove layout
* Fix bf16 error
* Add f16 and bf16 example
* Add more operations
* Implement IsSupportedArgument
* Add changelog
* Add comment
* Add comment
* Remove useless header
* Move initialize of workspace to the run
* Move set din zero to the device operator
* Save din_length_raw
* Remove useless header
* Calculate gridsize according to the number of CU
* Calculate gridSize according to the number of CU.
Remove useless header
* Add put example
* Remove useless header
* Fix CI fail
[ROCm/composable_kernel commit: 341ad95665]
This commit is contained in:
88
example/50_put_element/put_element_fp16.cpp
Normal file
88
example/50_put_element/put_element_fp16.cpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
// Element and index types used by this example: fp16 values addressed through
// 32-bit signed indices.
using IndexDataType = int32_t;
using XDataType     = ck::half_t;
using YDataType     = ck::half_t;

// Element-wise functor type handed to the device operation.
// NOTE(review): PassThrough presumably forwards each value unchanged - confirm
// against element_wise_operation.hpp.
using YElementwiseOp = ck::tensor_operation::element_wise::PassThrough;

// Concrete put-element device operation instantiated for the types above.
using DeviceInstance = ck::tensor_operation::device::DevicePutElementImpl<
    XDataType,                          // XDataType
    IndexDataType,                      // IndexDataType
    YDataType,                          // YDataType
    YElementwiseOp,                     // element-wise functor type
    ck::InMemoryDataOperationEnum::Set, // memory operation used for the output
    1>; // NOTE(review): presumably a vector width - confirm in device_put_element_impl.hpp
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
int N = 1024;
|
||||
|
||||
Tensor<XDataType> x(HostTensorDescriptor{N, 1});
|
||||
Tensor<IndexDataType> indices(HostTensorDescriptor{N, 1});
|
||||
Tensor<YDataType> y(HostTensorDescriptor{N, 1});
|
||||
|
||||
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0, 1.0});
|
||||
for(int i = 0; i < N; ++i)
|
||||
indices(i) = i;
|
||||
|
||||
DeviceMem x_device_buf(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
|
||||
DeviceMem y_device_buf(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
|
||||
DeviceMem indices_device_buf(sizeof(IndexDataType) * indices.mDesc.GetElementSpaceSize());
|
||||
|
||||
x_device_buf.ToDevice(x.mData.data());
|
||||
indices_device_buf.ToDevice(indices.mData.data());
|
||||
|
||||
auto put_instance = DeviceInstance{};
|
||||
auto put_invoker_ptr = put_instance.MakeInvokerPointer();
|
||||
auto put_argument_ptr = put_instance.MakeArgumentPointer(
|
||||
static_cast<XDataType*>(x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(indices_device_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
N,
|
||||
YElementwiseOp{});
|
||||
|
||||
if(!put_instance.IsSupportedArgument(put_argument_ptr.get()))
|
||||
{
|
||||
throw std::runtime_error("argument is not supported!");
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
put_invoker_ptr->Run(put_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<YDataType> y_host(HostTensorDescriptor{N, 1});
|
||||
|
||||
for(int i = 0; i < N; ++i)
|
||||
{
|
||||
IndexDataType idx = indices(i);
|
||||
y_host(idx) = x(i);
|
||||
}
|
||||
|
||||
y_device_buf.FromDevice(y.mData.data());
|
||||
pass = ck::utils::check_err(y, y_host);
|
||||
}
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
Reference in New Issue
Block a user