mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
experimenting global and buffer load/store
This commit is contained in:
@@ -8,6 +8,114 @@ namespace ck {
|
||||
// cast a pointer of LDS to its address
|
||||
extern "C" __attribute__((address_space(3))) __device__ void* __to_local(void* p);
|
||||
|
||||
// buffer_load and buffer_store
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ typename vector_type<T, VectorSize>::MemoryType
|
||||
buffer_load(const T* p_src_block, uint32_t src_thread_offset, uint32_t src_const_offset);
|
||||
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ void buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src,
|
||||
T* p_dst_block,
|
||||
uint32_t dst_thread_offset,
|
||||
uint32_t dst_const_offset);
|
||||
|
||||
template <>
|
||||
__device__ float buffer_load<float, 1>(const float* p_src_block,
|
||||
uint32_t src_thread_offset,
|
||||
uint32_t src_const_offset)
|
||||
{
|
||||
float dst;
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_load_dword %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_offset), "s"(src_block_setting), "s"(src_const_offset));
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ vector_type<float, 2>::MemoryType buffer_load<float, 2>(const float* p_src_block,
|
||||
uint32_t src_thread_offset,
|
||||
uint32_t src_const_offset)
|
||||
{
|
||||
vector_type<float, 2>::MemoryType dst;
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_load_dwordx2 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_offset), "s"(src_block_setting), "s"(src_const_offset));
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ vector_type<float, 4>::MemoryType buffer_load<float, 4>(const float* p_src_block,
|
||||
uint32_t src_thread_offset,
|
||||
uint32_t src_const_offset)
|
||||
{
|
||||
vector_type<float, 4>::MemoryType dst;
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_load_dwordx4 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_offset), "s"(src_block_setting), "s"(src_const_offset));
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void buffer_store<float, 1>(const float& src,
|
||||
float* p_dst_block,
|
||||
uint32_t dst_thread_offset,
|
||||
uint32_t dst_const_offset)
|
||||
{
|
||||
int32x4_t dst_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&dst_block_setting) = p_dst_block;
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&dst_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&dst_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_store_dword %1, %2, %0, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_setting), "v"(src), "v"(dst_thread_offset), "s"(dst_const_offset));
|
||||
}
|
||||
|
||||
__device__ void vmcnt(index_t cnt)
|
||||
{
|
||||
if(cnt == 0)
|
||||
|
||||
Reference in New Issue
Block a user