Add new DeviceMem API to copy memory

This commit is contained in:
Po-Yen, Chen
2022-08-21 01:16:47 -04:00
parent 1facdbd08f
commit 3f190b0779
2 changed files with 64 additions and 6 deletions

View File

@@ -3,8 +3,12 @@
#pragma once
#include <cassert>
#include <hip/hip_runtime.h>
#include "ck/utility/data_type.hpp"
template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
@@ -14,19 +18,50 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
}
}
struct DeviceMem
class DeviceMem
{
void ToDeviceImpl(const void* p) const;
void FromDeviceImpl(void* p) const;
public:
DeviceMem() = delete;
DeviceMem(std::size_t mem_size);
void* GetDeviceBuffer() const;
std::size_t GetBufferSize() const;
void ToDevice(const void* p) const;
void FromDevice(void* p) const;
void SetZero() const;
template <typename T>
void SetValue(T x) const;
~DeviceMem();
template <typename T>
std::enable_if_t<!std::is_const_v<T> &&
(std::is_same_v<T, int8_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck::half_t> ||
std::is_same_v<T, float> || std::is_same_v<T, double>)>
FromDevice(T* host) const
{
assert(device.GetBufferSize() % sizeof(T) == 0);
FromDeviceImpl(host);
}
template <typename T>
std::enable_if_t<std::is_same_v<T, int8_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck::half_t> ||
std::is_same_v<T, ck::half_t> || std::is_same_v<T, float> ||
std::is_same_v<T, double>>
ToDevice(const T* host) const
{
assert(device.GetBufferSize() % sizeof(T) == 0);
ToDeviceImpl(host);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
void ToDevice(const ck::int4_t* host) const;
void FromDevice(ck::int4_t* host) const;
#endif
void* mpDeviceBuf;
std::size_t mMemSize;
};

View File

@@ -1,8 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host_utility/hip_check_error.hpp"
#include <memory>
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/device_memory.hpp"
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
@@ -14,16 +15,38 @@ void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
void DeviceMem::ToDevice(const void* p) const
void DeviceMem::ToDeviceImpl(const void* p) const
{
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
void DeviceMem::FromDeviceImpl(void* p) const
{
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
void DeviceMem::ToDevice(const ck::int4_t* host) const
{
const std::size_t count = GetBufferSize();
const auto buffer = std::make_unique<int8_t[]>(count);
std::copy_n(host, count, buffer.get());
ToDevice(buffer.get());
}
void DeviceMem::FromDevice(ck::int4_t* host) const
{
const std::size_t count = GetBufferSize();
const auto buffer = std::make_unique<int8_t[]>(count);
FromDevice(buffer.get());
std::copy_n(buffer.get(), count, host);
}
#endif
void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }