Files
composable_kernel/library/src/utility/device_memory.cpp
zjing14 f5ec04f091 Grouped Gemm with Fixed K and N with SplitK (#818)
* move all arguments into device

* add b2c_tile_map

* add examples

* add SetDeviceKernelArgs

* dedicated fixed_nk solution

* init client api

* add grouped_gemm_bias example

* add an instance

* add instances

* formatting

* fixed cmake

* Update EnableCompilerWarnings.cmake

* Update cmake-ck-dev.sh

* clean; fixed comments

* fixed comment

* add instances for fp32 output

* add instances for fp32 output

* add fp32 out client example

* fixed CI

* init commit for kbatch

* add splitk gridwise

* format

* fixed

* clean deviceop

* clean code

* finish splitk

* fixed instances

* change m_loops to tile_loops

* add setkbatch

* clean code

* add splitK+bias

* add instances

* opt mk_nk instances

* clean examples

* fixed CI

* remove zero

* finished non-zero

* clean

* clean code

* optimized global_barrier

* fixed ci

* fixed CI

* removed AddBias

* format

* fixed CI

* fixed CI

* move 20_grouped_gemm to 21_grouped_gemm

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
2023-08-31 09:22:12 -05:00

77 lines
1.8 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <exception>
#include <iostream>
#include <stdexcept>

#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/device_memory.hpp"
// Construct a device buffer of `mem_size` bytes via hipMalloc.
// The HIP status code is passed to hip_check_error for reporting.
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
    void** const out_ptr = &mpDeviceBuf;
    hip_check_error(hipMalloc(out_ptr, mMemSize));
}
/// Replace the current device allocation with a fresh one of `mem_size` bytes.
///
/// The old contents are not copied over. The old buffer is released before the
/// new one is requested, so the two allocations never coexist.
///
/// If the new hipMalloc fails, the object is left empty: null buffer, size 0.
/// This assumes hip_check_error throws on a non-success status (see its header).
/// Previously the object kept the already-freed pointer, and the destructor
/// would free it a second time.
void DeviceMem::Realloc(std::size_t mem_size)
{
    if(mpDeviceBuf)
    {
        hip_check_error(hipFree(mpDeviceBuf));
    }
    // Drop the stale handle immediately so no later path can free it again.
    mpDeviceBuf = nullptr;
    mMemSize    = 0;

    void* new_buf = nullptr;
    hip_check_error(hipMalloc(&new_buf, mem_size));

    // Commit pointer and size together, only after the allocation succeeded.
    mpDeviceBuf = new_buf;
    mMemSize    = mem_size;
}
// Return the raw device pointer held by this object.
// It may be null when no buffer is allocated.
void* DeviceMem::GetDeviceBuffer() const
{
    void* const buf = mpDeviceBuf;
    return buf;
}
std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
/// Copy the whole buffer (mMemSize bytes) from host memory at `p` to the device.
/// @param p Host source; must point to at least mMemSize readable bytes.
/// @throws std::runtime_error if no device buffer is allocated.
void DeviceMem::ToDevice(const void* p) const
{
    if(!mpDeviceBuf)
    {
        throw std::runtime_error("ToDevice with an empty pointer");
    }
    // hipMemcpy takes a `const void*` source, so no const_cast is required.
    hip_check_error(hipMemcpy(mpDeviceBuf, p, mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const
{
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
}
// Copy the whole device buffer (mMemSize bytes) into host memory at `p`.
// Throws std::runtime_error when there is no device buffer to read from.
void DeviceMem::FromDevice(void* p) const
{
    if(mpDeviceBuf == nullptr)
    {
        throw std::runtime_error("FromDevice with an empty pointer");
    }
    const auto status = hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost);
    hip_check_error(status);
}
void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const
{
hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
// Fill all mMemSize bytes of the device buffer with zero.
// Does nothing when no buffer is allocated.
void DeviceMem::SetZero() const
{
    if(mpDeviceBuf == nullptr)
    {
        return;
    }
    constexpr int zero_byte = 0;
    hip_check_error(hipMemset(mpDeviceBuf, zero_byte, mMemSize));
}
/// Release the device buffer, if any.
///
/// Destructors are noexcept by default, so an exception escaping here would
/// call std::terminate. hip_check_error may throw on a failed hipFree, so that
/// failure is reported on std::cerr instead of being propagated.
DeviceMem::~DeviceMem()
{
    if(mpDeviceBuf)
    {
        try
        {
            hip_check_error(hipFree(mpDeviceBuf));
        }
        catch(const std::exception& e)
        {
            std::cerr << "DeviceMem: failed to free device buffer: " << e.what() << '\n';
        }
    }
}