mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
* Initial implementation:
- add new thread group transfer supporting transpose instruction
- refactor AB transfer to switch between thread and wave tiles methods
* Add some comments and remove explicit wave and lane calculations
* Remove compiler option for performance
* fp16 example: use tuned instance
* Missing cleanup
* Integrate wave transfer in existing gemm and batched gemm instances
* Add fast instances
* extend implementation for 8 bit datatypes
packed types not supported
* Address review comments
* Optimize pipeline v1 and re-introduce compiler option
* Disable wave tile approach for b scale gemm
* Fix for clang20
* Avoid code duplication of amd_global_load_transpose_to_vgpr function
[ROCm/composable_kernel commit: 440358c168]
38 lines
1.2 KiB
C++
38 lines
1.2 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
#include "data_type.hpp"
|
|
|
|
namespace ck {
|
|
|
|
#if defined(__gfx12__)
|
|
template <typename T>
|
|
__device__ auto amd_global_load_transpose_to_vgpr(const T* in_ptr)
|
|
{
|
|
using vector_t = typename vector_type<T, 8>::type;
|
|
if constexpr(sizeof(T) == 2)
|
|
{
|
|
typedef __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16 llvm_fp16x8_t;
|
|
__attribute__((address_space(1))) llvm_fp16x8_t* glb_ptr =
|
|
reinterpret_cast<__attribute__((address_space(1))) llvm_fp16x8_t*>(
|
|
reinterpret_cast<uintptr_t>(in_ptr));
|
|
return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b128_v8f16(glb_ptr));
|
|
}
|
|
else if constexpr(sizeof(T) == 1)
|
|
{
|
|
typedef __attribute__((__vector_size__(2 * sizeof(int)))) int llvm_intx2_t;
|
|
__attribute__((address_space(1))) llvm_intx2_t* glb_ptr =
|
|
reinterpret_cast<__attribute__((address_space(1))) llvm_intx2_t*>(
|
|
reinterpret_cast<uintptr_t>(in_ptr));
|
|
return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b64_v2i32(glb_ptr));
|
|
}
|
|
else
|
|
{
|
|
static_assert(false, "not implemented");
|
|
}
|
|
}
|
|
#endif
|
|
|
|
} // namespace ck
|