Transpose builtin macro defense (#2374)

* add the macro defense

* add the static assert check

[ROCm/composable_kernel commit: 107e3623c7]
This commit is contained in:
Thomas Ning
2025-06-20 11:24:54 -07:00
committed by GitHub
parent b7f7728e82
commit 7e4994ac35
3 changed files with 50 additions and 4 deletions

View File

@@ -2784,10 +2784,13 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
#endif
}
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
@@ -2817,6 +2820,7 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
static_assert(false, "not implemented");
}
}
#endif
} // namespace ck_tile

View File

@@ -2554,6 +2554,44 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
#endif
}
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
}
#endif
} // namespace ck_tile
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN

View File

@@ -902,8 +902,9 @@ struct buffer_view<address_space_enum::lds,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
transpose_get(index_t i, index_t linear_offset, bool is_valid_element) const
CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
[[maybe_unused]] index_t linear_offset,
bool is_valid_element) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -913,13 +914,16 @@ struct buffer_view<address_space_enum::lds,
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if(is_valid_element)
{
#if defined(__gfx950__)
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr address_space_enum addr_space = get_address_space();
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
p_data_ + i + linear_offset);
#else
return X{numeric<remove_cvref_t<T>>::zero()};
#endif
}
else
{