mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
* Initial implementation:
  - add a new thread-group transfer that supports the transpose instruction
  - refactor the AB transfer to switch between the thread-tile and wave-tile methods
* Add some comments and remove explicit wave and lane calculations
* Remove compiler option for performance
* fp16 example: use tuned instance
* Missing cleanup
* Integrate wave transfer in existing gemm and batched gemm instances
* Add fast instances
* Extend implementation to 8-bit data types (packed types not supported)
* Address review comments
* Optimize pipeline v1 and re-introduce compiler option
* Disable wave-tile approach for b scale gemm
* Fix for clang20
* Avoid code duplication of amd_global_load_transpose_to_vgpr function
38 lines
1.2 KiB
C++
38 lines
1.2 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
#include "data_type.hpp"
|
|
|
|
namespace ck {
|
|
|
|
#if defined(__gfx12__)
|
|
template <typename T>
|
|
__device__ auto amd_global_load_transpose_to_vgpr(const T* in_ptr)
|
|
{
|
|
using vector_t = typename vector_type<T, 8>::type;
|
|
if constexpr(sizeof(T) == 2)
|
|
{
|
|
typedef __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16 llvm_fp16x8_t;
|
|
__attribute__((address_space(1))) llvm_fp16x8_t* glb_ptr =
|
|
reinterpret_cast<__attribute__((address_space(1))) llvm_fp16x8_t*>(
|
|
reinterpret_cast<uintptr_t>(in_ptr));
|
|
return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b128_v8f16(glb_ptr));
|
|
}
|
|
else if constexpr(sizeof(T) == 1)
|
|
{
|
|
typedef __attribute__((__vector_size__(2 * sizeof(int)))) int llvm_intx2_t;
|
|
__attribute__((address_space(1))) llvm_intx2_t* glb_ptr =
|
|
reinterpret_cast<__attribute__((address_space(1))) llvm_intx2_t*>(
|
|
reinterpret_cast<uintptr_t>(in_ptr));
|
|
return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b64_v2i32(glb_ptr));
|
|
}
|
|
else
|
|
{
|
|
static_assert(false, "not implemented");
|
|
}
|
|
}
|
|
#endif
|
|
|
|
} // namespace ck
|