mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Wave Tile Transfer supporting global load with transpose (#3027)
* Initial implementation:
  - add new thread group transfer supporting the transpose instruction
  - refactor AB transfer to switch between thread-tile and wave-tile methods
* Add some comments and remove explicit wave and lane calculations
* Remove compiler option for performance
* fp16 example: use tuned instance
* Missing cleanup
* Integrate wave transfer in existing gemm and batched gemm instances
* Add fast instances
* Extend implementation for 8-bit datatypes; packed types are not supported
* Address review comments
* Optimize pipeline v1 and re-introduce compiler option
* Disable wave-tile approach for b-scale gemm
* Fix for clang20
* Avoid code duplication of the amd_global_load_transpose_to_vgpr function
This commit is contained in:
37
include/ck/utility/amd_transpose_load.hpp
Normal file
37
include/ck/utility/amd_transpose_load.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <cstdint>

#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if defined(__gfx12__)
|
||||
template <typename T>
__device__ auto amd_global_load_transpose_to_vgpr(const T* in_ptr)
{
    // Load 8 elements of T from global memory into VGPRs using the gfx12
    // transposing load builtins, returning them bit_cast to the ck vector
    // type. Only 2-byte and 1-byte element types are supported.
    //
    // in_ptr: global-memory source address. NOTE(review): the hardware
    //         transpose-load presumably has alignment requirements on the
    //         address (16 B for tr_b128 / 8 B for tr_b64) — confirm callers
    //         guarantee this; nothing here checks it.
    // return: vector_type<T, 8>::type holding the loaded lanes.
    using vector_t = typename vector_type<T, 8>::type;

    if constexpr(sizeof(T) == 2)
    {
        // 8 x 16-bit elements via global_load_tr_b128 (128-bit payload).
        typedef __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16 llvm_fp16x8_t;
        // The builtin requires an address-space(1) (global) pointer; the
        // uintptr_t round-trip strips the generic address space qualifier.
        __attribute__((address_space(1))) llvm_fp16x8_t* glb_ptr =
            reinterpret_cast<__attribute__((address_space(1))) llvm_fp16x8_t*>(
                reinterpret_cast<uintptr_t>(in_ptr));
        return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b128_v8f16(glb_ptr));
    }
    else if constexpr(sizeof(T) == 1)
    {
        // 8 x 8-bit elements via global_load_tr_b64 (64-bit payload,
        // exposed by the builtin as 2 x i32).
        typedef __attribute__((__vector_size__(2 * sizeof(int)))) int llvm_intx2_t;
        __attribute__((address_space(1))) llvm_intx2_t* glb_ptr =
            reinterpret_cast<__attribute__((address_space(1))) llvm_intx2_t*>(
                reinterpret_cast<uintptr_t>(in_ptr));
        return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b64_v2i32(glb_ptr));
    }
    else
    {
        // Dependent condition instead of static_assert(false, ...): a
        // non-dependent false assertion in a discarded constexpr branch is
        // ill-formed on compilers without the CWG 2518 resolution (pre-C++23
        // semantics); sizeof(T) == 0 fires only when this branch is actually
        // instantiated with an unsupported element size.
        static_assert(sizeof(T) == 0, "not implemented");
    }
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
@@ -12,6 +12,7 @@
|
||||
#else
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "amd_transpose_load.hpp"
|
||||
#include "generic_memory_space_atomic.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -69,6 +70,7 @@ struct DynamicBuffer
|
||||
__host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; }
|
||||
|
||||
template <typename X,
|
||||
bool DoTranspose = false,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value ||
|
||||
!is_native_type<X>(),
|
||||
@@ -89,7 +91,8 @@ struct DynamicBuffer
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing &&
|
||||
!DoTranspose)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
@@ -112,6 +115,14 @@ struct DynamicBuffer
|
||||
invalid_element_value_);
|
||||
}
|
||||
}
|
||||
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && DoTranspose)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return amd_global_load_transpose_to_vgpr(p_data_ + i);
|
||||
#else
|
||||
static_assert(!DoTranspose, "load-with-transpose only supported on gfx12+");
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(is_valid_element)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,15 +7,19 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#ifdef __gfx12__
|
||||
__device__ void llvm_amdgcn_s_wait_dscnt(short cnt) __asm("llvm.amdgcn.s.wait.dscnt");
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__device__ void block_sync_lds()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#ifdef __gfx12__
|
||||
asm volatile("\
|
||||
s_wait_dscnt 0x0 \n \
|
||||
s_barrier_signal -1 \n \
|
||||
s_barrier_wait -1 \
|
||||
" ::);
|
||||
llvm_amdgcn_s_wait_dscnt(0);
|
||||
asm volatile("s_barrier_signal -1\n\t"
|
||||
"s_barrier_wait -1");
|
||||
#else
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
|
||||
Reference in New Issue
Block a user