mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Transpose builtin macro defense (#2374)
* add the macro defense
* add the static assert check
[ROCm/composable_kernel commit: 107e3623c7]
This commit is contained in:
@@ -2784,10 +2784,13 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
@@ -2817,6 +2820,7 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
|
||||
@@ -2554,6 +2554,44 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -902,8 +902,9 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
transpose_get(index_t i, index_t linear_offset, bool is_valid_element) const
|
||||
CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
|
||||
[[maybe_unused]] index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -913,13 +914,16 @@ struct buffer_view<address_space_enum::lds,
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr address_space_enum addr_space = get_address_space();
|
||||
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
|
||||
p_data_ + i + linear_offset);
|
||||
#else
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user