Files
composable_kernel/library/include/ck/library/utility/device_memory.hpp
zjing14 c79ecbccfb Grouped Gemm with Fixed K and N with SplitK (#818)
* move all arguments into device

* add b2c_tile_map

* add examples

* add SetDeviceKernelArgs

* dedicated fixed_nk solution

* init client api

* add grouped_gemm_bias example

* add an instance

* add instances

* formatting

* fixed cmake

* Update EnableCompilerWarnings.cmake

* Update cmake-ck-dev.sh

* clean; fixed comments

* fixed comment

* add instances for fp32 output

* add instances for fp32 output

* add fp32 out client example

* fixed CI

* init commit for kbatch

* add splitk gridwise

* format

* fixed

* clean deviceop

* clean code

* finish splitk

* fixed instances

* change m_loops to tile_loops

* add setkbatch

* clean code

* add splitK+bias

* add instances

* opt mk_nk instances

* clean examples

* fixed CI

* remove zero

* finished non-zero

* clean

* clean code

* optimized global_barrier

* fixed ci

* fixed CI

* removed AddBias

* format

* fixed CI

* fixed CI

* move 20_grouped_gemm to 21_grouped_gemm

---------

Co-authored-by: Jing Zhang <jizha@amd.com>

[ROCm/composable_kernel commit: f5ec04f091]
2023-08-31 09:22:12 -05:00

51 lines
1.3 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
/**
 * @brief Device kernel that assigns @p x to elements of the array @p p.
 *
 * Every thread starts at the index given by its threadIdx.x and then advances
 * by blockDim.x until it reaches @p buffer_element_size. Only the
 * x-dimension of the thread index and block size is consulted; the block
 * index is not used, so each block of a launch visits the same indices.
 *
 * @tparam T                   Element type stored in the buffer.
 * @param p                    Pointer to the first element to overwrite.
 * @param x                    Value written to each visited element.
 * @param buffer_element_size  Number of elements (not bytes) to write.
 */
template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
    // Distance between two consecutive elements handled by the same thread.
    const uint64_t stride = blockDim.x;

    uint64_t idx = threadIdx.x;
    while(idx < buffer_element_size)
    {
        p[idx] = x;
        idx += stride;
    }
}
/**
* @brief Container for storing data in GPU device memory
*
* Pairs an untyped buffer pointer (mpDeviceBuf) with a size (mMemSize).
* SetValue() divides mMemSize by sizeof(T) to obtain an element count, i.e. it
* treats mMemSize as a size in bytes.
*
* Only the default constructor and the SetValue() template are defined in this
* header; all other members are declared here and defined elsewhere.
*
* NOTE(review): a destructor is declared but no copy/move operations are, so
* the implicitly generated copy operations duplicate mpDeviceBuf. If the
* destructor releases the buffer (presumably - confirm in the implementation
* file), copying a DeviceMem would be unsafe.
*/
struct DeviceMem
{
// Creates an empty object: null buffer pointer and a size of 0.
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
// Presumably allocates a device buffer of mem_size bytes - TODO confirm against the definition.
// Not marked explicit, so a std::size_t converts implicitly to DeviceMem.
DeviceMem(std::size_t mem_size);
// Presumably replaces the current buffer with one of mem_size bytes - TODO confirm.
void Realloc(std::size_t mem_size);
// Accessor for the buffer pointer; presumably returns mpDeviceBuf - TODO confirm.
void* GetDeviceBuffer() const;
// Accessor for the buffer size; presumably returns mMemSize - TODO confirm.
std::size_t GetBufferSize() const;
// Presumably copies from p into the device buffer; the amount copied is not
// visible here (likely the whole buffer) - TODO confirm.
void ToDevice(const void* p) const;
// As above, presumably limited to cpySize bytes - TODO confirm.
void ToDevice(const void* p, const std::size_t cpySize) const;
// Presumably copies the device buffer contents into p - TODO confirm.
void FromDevice(void* p) const;
// As above, presumably limited to cpySize bytes - TODO confirm.
void FromDevice(void* p, const std::size_t cpySize) const;
// Presumably zero-fills the device buffer - TODO confirm against the definition.
void SetZero() const;
// Fills the buffer with copies of x; defined below in this header.
// Throws std::runtime_error when mMemSize is not a multiple of sizeof(T).
template <typename T>
void SetValue(T x) const;
~DeviceMem();
// Untyped pointer to the buffer; nullptr after default construction.
void* mpDeviceBuf;
// Size associated with mpDeviceBuf; used as a byte count by SetValue().
std::size_t mMemSize;
};
/**
 * @brief Fill the buffer with copies of @p x, interpreted as elements of type T.
 *
 * The buffer size mMemSize must be an exact multiple of sizeof(T); otherwise
 * part of the buffer could not be covered and the call is rejected before
 * anything is launched.
 *
 * @tparam T  Element type used to interpret the buffer.
 * @param x   Value assigned to each element.
 * @throws std::runtime_error if mMemSize is not a multiple of sizeof(T).
 */
template <typename T>
void DeviceMem::SetValue(T x) const
{
    const std::size_t leftover_bytes = mMemSize % sizeof(T);
    if(leftover_bytes != 0)
    {
        throw std::runtime_error("wrong! not entire DeviceMem will be set");
    }

    T* const typed_buffer           = static_cast<T*>(mpDeviceBuf);
    const std::size_t element_count = mMemSize / sizeof(T);

    // Launch configuration is one block of 1024 threads; the kernel loops so
    // that element counts above 1024 are still covered.
    // NOTE(review): the status of the launch is not inspected here.
    set_buffer_value<T><<<1, 1024>>>(typed_buffer, x, element_count);
}