Add basic support for direct loads from global to LDS (#999)

* Add basic support for direct loads from global to LDS

* Clean the code and comments

* Add support for fp16

* Add comments

* Add check for thread cluster lengths

* Align non-direct-load fp16 example

* Small fixes

* Extend IsSupported to check for supported GPU gens

* Build examples only on the supported HW

* Do not throw when instance not supported in 04 example

* Review: Apply review suggestions

* Review: small fix

* Review: small fix

[ROCm/composable_kernel commit: 627054b941]
This commit is contained in:
Bartlomiej Wroblewski
2023-11-25 13:35:22 +01:00
committed by GitHub
parent b88a739b88
commit fbbbce4fb4
23 changed files with 2783 additions and 2 deletions

View File

@@ -944,4 +944,41 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
#endif
}
// Direct loads from global to LDS.
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
T* lds_base_ptr,
const index_t lds_offset,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
}
} // namespace ck

View File

@@ -173,6 +173,26 @@ struct DynamicBuffer
}
}
template <typename DstBuffer, index_t NumElemsPerThread>
__host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
index_t src_offset,
index_t dst_offset,
bool is_valid_element) const
{
// Copy data from global to LDS memory using direct loads.
static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
"Source data must come from a global memory buffer.");
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"Destination data must be stored in an LDS memory buffer.");
amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
src_offset,
dst_buf.p_data_,
dst_offset,
is_valid_element,
element_space_size_);
}
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,

View File

@@ -19,6 +19,15 @@ __device__ void block_sync_lds()
#endif
}
__device__ void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
__device__ void s_nop()
{
#if 1