mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
Add new DeviceMem API to copy memory
This commit is contained in:
@@ -3,8 +3,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
|
||||
{
|
||||
@@ -14,19 +18,50 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
|
||||
}
|
||||
}
|
||||
|
||||
struct DeviceMem
|
||||
class DeviceMem
|
||||
{
|
||||
void ToDeviceImpl(const void* p) const;
|
||||
void FromDeviceImpl(void* p) const;
|
||||
|
||||
public:
|
||||
DeviceMem() = delete;
|
||||
DeviceMem(std::size_t mem_size);
|
||||
void* GetDeviceBuffer() const;
|
||||
std::size_t GetBufferSize() const;
|
||||
void ToDevice(const void* p) const;
|
||||
void FromDevice(void* p) const;
|
||||
void SetZero() const;
|
||||
template <typename T>
|
||||
void SetValue(T x) const;
|
||||
~DeviceMem();
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_const_v<T> &&
|
||||
(std::is_same_v<T, int8_t> || std::is_same_v<T, int32_t> ||
|
||||
std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck::half_t> ||
|
||||
std::is_same_v<T, float> || std::is_same_v<T, double>)>
|
||||
FromDevice(T* host) const
|
||||
{
|
||||
assert(device.GetBufferSize() % sizeof(T) == 0);
|
||||
|
||||
FromDeviceImpl(host);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_same_v<T, int8_t> || std::is_same_v<T, int32_t> ||
|
||||
std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck::half_t> ||
|
||||
std::is_same_v<T, ck::half_t> || std::is_same_v<T, float> ||
|
||||
std::is_same_v<T, double>>
|
||||
ToDevice(const T* host) const
|
||||
{
|
||||
assert(device.GetBufferSize() % sizeof(T) == 0);
|
||||
|
||||
ToDeviceImpl(host);
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
void ToDevice(const ck::int4_t* host) const;
|
||||
void FromDevice(ck::int4_t* host) const;
|
||||
#endif
|
||||
|
||||
void* mpDeviceBuf;
|
||||
std::size_t mMemSize;
|
||||
};
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include <memory>
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
|
||||
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
@@ -14,16 +15,38 @@ void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
|
||||
|
||||
std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
|
||||
|
||||
void DeviceMem::ToDevice(const void* p) const
|
||||
void DeviceMem::ToDeviceImpl(const void* p) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p) const
|
||||
void DeviceMem::FromDeviceImpl(void* p) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
void DeviceMem::ToDevice(const ck::int4_t* host) const
|
||||
{
|
||||
const std::size_t count = GetBufferSize();
|
||||
const auto buffer = std::make_unique<int8_t[]>(count);
|
||||
|
||||
std::copy_n(host, count, buffer.get());
|
||||
|
||||
ToDevice(buffer.get());
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(ck::int4_t* host) const
|
||||
{
|
||||
const std::size_t count = GetBufferSize();
|
||||
const auto buffer = std::make_unique<int8_t[]>(count);
|
||||
|
||||
FromDevice(buffer.get());
|
||||
|
||||
std::copy_n(buffer.get(), count, host);
|
||||
}
|
||||
#endif
|
||||
|
||||
void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
|
||||
|
||||
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
|
||||
|
||||
Reference in New Issue
Block a user