mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline
This commit is contained in:
@@ -1,6 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/**
|
||||
* @file
|
||||
* We're defining the data access pattern for a 2D window (`XPerTile` by `YPerTile`)
|
||||
for `BlockSize` threads in a thread block.
|
||||
* X dimension is considered contiguous in memory, so a single instruction can access
|
||||
several adjacent and properly aligned elements (vector); the access pattern along X tile
|
||||
dimension is parameterized only by the suggested vector size `VecSize`.
|
||||
* We can't access more than `MaxVecSize = TileElementsPerThread = TileSize / BlockSize` elements
|
||||
with a single memory access, so the actual vector size along the X dimension is
|
||||
`X0 = min(MaxVecSize, VecSize)`.
|
||||
* This leaves `X1 = XPerTile / X0` threads per tile in X dimension.
|
||||
* X1 is also the number of threads per warp in X dimension, that is,
|
||||
X dimension is not split between warps, and each warp accesses X dimension entirely,
|
||||
and there is no iteration in X dimension.
|
||||
* The tuple <X0, X1> defines the X-axis access pattern.
|
||||
This part is common between the 2D distribution patterns.
|
||||
|
||||
* What's different between the different 2D distribution patterns, is the Y axis access pattern.
|
||||
* There are 3 components in this access pattern;
|
||||
* (1) number of Y-axis elements (rows) per warp for a single instruction access,
|
||||
* (2) number of warps per thread block,
|
||||
* (3) number of iterations to cover the entire Y axis.
|
||||
|
||||
* The raked here represents how data is partitioned across different processing granularity.
|
||||
* It represents howe we are going to access the data in thread, warp, or blocked in contiguous
|
||||
region.
|
||||
* From below, the qualifier for 'raked' is the part of warp/thread hierarchy
|
||||
* in the split of Y tile dimension where the iteration happens,
|
||||
* meaning, the iteration can be logically inserted as a tile dimension in 3 ways,
|
||||
* (1) after thread -> thread-raked,
|
||||
* (2) between warp and thread -> warp-raked,
|
||||
* (3) before warp -> block-raked
|
||||
|
||||
* *Thread raked*
|
||||
|
||||
* Y0 is the number of warps, which we can get from the equation `Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of rows accessed by a warp within a single iteration,
|
||||
compute it from the equation `Y0 * X1 == WarpSize`
|
||||
* Y2 is the number of iterations to cover the tile,
|
||||
compute it from the equation `Y0 * Y1 * Y2 == YPerTile`
|
||||
|
||||
* *Warp raked*
|
||||
|
||||
* Y0 is the number of warps, we can get it in the same way as for thread-raked pattern,
|
||||
`Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y2 from the equation below
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* *Block raked*
|
||||
|
||||
* Y0 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y1 and Y2 from the equations below
|
||||
* Y1 is the number of warps, `Y1 * WarpSize == BlockSize`
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* In all cases, the tuple <Y0, Y1, Y2> defines the Y-axis access pattern.
|
||||
|
||||
* *Selection*
|
||||
* When we are selecting, Thread-raked is used in element-wise operation because it is the
|
||||
* Thread-major memory order.
|
||||
* Warp-raked is used in matrix multiplication because the vectorization is in warp level.
|
||||
* Block-raked is used mostly for the reduction process, where will reduce the block in global
|
||||
* atomic level.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
@@ -105,9 +172,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -115,9 +182,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
sequence<2, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,9 +196,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <X1, Y2>
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -139,9 +206,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
sequence<1, 2>>{}); // -> <X1, Y2>
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -182,9 +249,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <Y1, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
@@ -193,9 +260,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <X1, Y1>
|
||||
}
|
||||
};
|
||||
|
||||
@@ -233,9 +300,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
sequence<0, 1>>{}); // -> <Y0, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
@@ -244,9 +311,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
sequence<1, 0>>{}); // -> <X1, Y0>
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
// This attribute gives a hint to the compiler that a branch is likely to be taken.
|
||||
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
|
||||
@@ -23,6 +24,8 @@
|
||||
#define LIKELY(x) (__builtin_expect(!!(x), 1))
|
||||
#endif
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
@@ -1270,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
@@ -1760,7 +1763,7 @@ template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
@@ -1790,29 +1793,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
ignore = src_wave_addr_offset;
|
||||
ignore = src_immediate_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2786,15 +2821,52 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
as3_uint32_ptr lds_ptr =
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -1138,7 +1141,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
@@ -1560,29 +1563,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
ignore = src_wave_addr_offset;
|
||||
ignore = src_immediate_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2556,15 +2591,52 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
as3_uint32_ptr lds_ptr =
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
86
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
86
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this generate wave level tile distribution
|
||||
template <typename T, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief This function is used to generate the transposed distribution encoding
|
||||
* for the given data type and distribution dimensions.
|
||||
*
|
||||
* @tparam T The data type of the elements in the tensor.
|
||||
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
|
||||
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
|
||||
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
|
||||
* consecutive.
|
||||
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
|
||||
return xdllevel_dstr_encoding{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -10,6 +10,15 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
|
||||
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
|
||||
#define CK_TILE_VMCNT(cnt) \
|
||||
([]() { static_assert((cnt) < 0b111111, "VMCNT only has 6 bits"); }(), \
|
||||
((cnt)&0b1111) | (((cnt)&0b110000) << 10))
|
||||
#define CK_TILE_EXPCNT(cnt) \
|
||||
([]() { static_assert((cnt) < 0b111, "EXP only has 3 bits"); }(), ((cnt) << 4))
|
||||
#define CK_TILE_LGKMCNT(cnt) \
|
||||
([]() { static_assert((cnt) < 0b1111, "LGKM only has 4 bits"); }(), ((cnt) << 8))
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename, bool>
|
||||
@@ -113,13 +122,12 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t vmcnt>
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
asm volatile("\
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
// We don't sync the lds insts here.
|
||||
__builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_VMCNT(vmcnt));
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
|
||||
|
||||
@@ -263,3 +263,9 @@
|
||||
#ifndef CK_TILE_WA_ISSUE_2028
|
||||
#define CK_TILE_WA_ISSUE_2028 0
|
||||
#endif
|
||||
|
||||
// Y pointed to R, we don't see a valuable use case.
|
||||
// Will enforce encoding to check Y not pointed to R if set to zero
|
||||
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
#define CK_TILE_ENC_SUPPORT_Y_TO_R 0
|
||||
#endif
|
||||
|
||||
@@ -1178,6 +1178,15 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
// the length count for slice size is from right to left(reverse slice)
|
||||
// or we can say, find the greatest common divider(gcd) from right to left, for the slice length
|
||||
//
|
||||
// e.g. <2, 8, 4>, slice length = 16
|
||||
// step-1: we take the right most <*, *, 4>, remaining 16/4=4
|
||||
// step-2: we only need 4 out of 8, of the midden dim, hence <*, 4, 4>
|
||||
// step-3: since nonthing remain, so the first dim we only need 1, hence<1, 4, 4>
|
||||
// => we got <1, 4, 4> as length for each slice
|
||||
// => total number of slice = <2, 8, 4> / <1, 4, 4> = <2, 2, 1>
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
@@ -1197,7 +1206,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
|
||||
// have split slices (right -> left)
|
||||
// or the first index that sliced length is different from the original length
|
||||
// or the first index (right -> left) that sliced length is different from the original length
|
||||
// clang-format on
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
@@ -1207,6 +1216,11 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
static_assert(SliceSize != 0, "slice size zero is invalid");
|
||||
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
|
||||
SliceSize ==
|
||||
0,
|
||||
"slice size can't evenly divide input sizes");
|
||||
using sliced_type =
|
||||
impl::reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using index_t = int32_t;
|
||||
using int32_t = int32_t;
|
||||
using long_index_t = int64_t;
|
||||
using int8_t = int8_t;
|
||||
|
||||
|
||||
213
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
213
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
@@ -0,0 +1,213 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
// modify from include/ck/utility/mxfp_utils.hpp
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils : numeric_traits<T>
|
||||
{
|
||||
|
||||
using traits = numeric_traits<T>;
|
||||
using _numeric = numeric<T>;
|
||||
using raw_type = typename T::raw_type;
|
||||
|
||||
static constexpr int exp_mask = (1 << traits::exp) - 1;
|
||||
|
||||
static constexpr int get_exponent(raw_type x)
|
||||
{
|
||||
// TODO: check if repeated calls are optimized.
|
||||
return (x >> traits::mant) & exp_mask;
|
||||
}
|
||||
static constexpr bool is_positive(raw_type x)
|
||||
{
|
||||
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
|
||||
}
|
||||
static constexpr bool is_subnormal(raw_type x)
|
||||
{
|
||||
return get_exponent(x) == _numeric::binary_zero;
|
||||
}
|
||||
// TODO: replace double with template arg?
|
||||
static constexpr double get_mantissa(raw_type x)
|
||||
{
|
||||
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
|
||||
for(uint32_t i = 0; i < traits::mant; ++i)
|
||||
{
|
||||
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
|
||||
x >>= 1;
|
||||
}
|
||||
return mantissa;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
|
||||
{
|
||||
using utils = numeric_utils<T>;
|
||||
static constexpr int e8m0_bias = 127; // TODO: make it generic.
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
|
||||
return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
|
||||
{
|
||||
using bitwise_type = typename numeric_traits<T>::bitwise_type;
|
||||
|
||||
if(std::abs(value) > float(numeric<T>::max()))
|
||||
{
|
||||
float max_value = numeric<T>::max();
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
bitwise_type sign =
|
||||
bit_cast<uint32_t>(value) >> (numeric_traits<float>::exp + numeric_traits<float>::mant);
|
||||
bitwise_type exp =
|
||||
((max_bitwise >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask) -
|
||||
(numeric_traits<float>::bias - numeric_traits<T>::bias);
|
||||
bitwise_type mantissa =
|
||||
max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
|
||||
uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
|
||||
mant_prev--;
|
||||
|
||||
mant_prev <<= (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
uint32_t prev_bit =
|
||||
((max_bitwise >> numeric_traits<float>::mant) << numeric_traits<float>::mant) |
|
||||
mant_prev;
|
||||
|
||||
float prev_val = bit_cast<float>(prev_bit);
|
||||
float diff = max_value - prev_val;
|
||||
|
||||
float actual_max = max_value + (diff / 2);
|
||||
|
||||
if(std::abs(value) < actual_max)
|
||||
{
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!numeric<T>::has_inf())
|
||||
{
|
||||
|
||||
return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
exp++;
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant);
|
||||
}
|
||||
}
|
||||
}
|
||||
const int mfmt = numeric_traits<float>::mant;
|
||||
uint32_t x;
|
||||
x = bit_cast<uint32_t>(value);
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int32_t exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
head = x & numeric_traits<float>::head_mask;
|
||||
mantissa = x & numeric_traits<float>::mant_mask;
|
||||
exponent = (head >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask;
|
||||
sign = head >> (numeric_traits<float>::mant + numeric_traits<float>::exp);
|
||||
bias = numeric_traits<float>::bias;
|
||||
|
||||
if(x == 0)
|
||||
{
|
||||
return 0b0;
|
||||
}
|
||||
|
||||
const int mini_bias = numeric_traits<T>::bias;
|
||||
const int mini_denormal_act_exponent = 1 - mini_bias;
|
||||
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
bool is_subnorm = false;
|
||||
|
||||
if(exponent == 0)
|
||||
{
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= mini_denormal_act_exponent)
|
||||
{
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
exponent_diff = 0;
|
||||
}
|
||||
mantissa += (1UL << mfmt);
|
||||
}
|
||||
|
||||
auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
|
||||
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
|
||||
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
|
||||
|
||||
float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
|
||||
|
||||
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
|
||||
{
|
||||
// closer to 0
|
||||
if(std::abs(value) <= std::abs(min_subnorm - value))
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
else
|
||||
return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
|
||||
}
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
|
||||
bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
|
||||
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
|
||||
|
||||
if(out_exponent == 0)
|
||||
{
|
||||
if((1UL << mfmt) & mantissa)
|
||||
{
|
||||
out_exponent = 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1UL << (mfmt + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
out_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - numeric_traits<T>::mant);
|
||||
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
{
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
}
|
||||
|
||||
mantissa &= (1UL << numeric_traits<T>::mant) - 1;
|
||||
return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(out_exponent << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -103,94 +103,92 @@ struct numeric_traits<float>
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) == static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) != static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return std::abs(static_cast<float>(x) - static_cast<float>(y)) < \
|
||||
static_cast<float>(numeric<type_>::epsilon()); \
|
||||
} \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) { return not operator==(x, y); } \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
}
|
||||
|
||||
324
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
324
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
@@ -0,0 +1,324 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#define CK_TILE_FP4_CVT_DEVICE 1
|
||||
#else
|
||||
#define CK_TILE_FP4_CVT_DEVICE 0
|
||||
#endif
|
||||
|
||||
#define TEST_convert_with_table 0
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
|
||||
|
||||
// TODO: Add stochastic method
|
||||
struct pk_float4_e2m1_t
|
||||
{
|
||||
static constexpr int exponent = 2;
|
||||
static constexpr int mantissa = 1;
|
||||
static constexpr int bias = 1;
|
||||
// TODO: Can we merge raw_type and type?
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {}
|
||||
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE raw_type unpack(number<I>) const;
|
||||
CK_TILE_HOST_DEVICE static pk_float4_e2m1_t pack(const type x0, const type x1)
|
||||
{
|
||||
return (x1 << 4) | (x0 & 0b00001111);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table
|
||||
static constexpr float e2m1_to_fp32_table[16] = {
|
||||
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
|
||||
static constexpr fp16_t e2m1_to_fp16_table[16] = {
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
|
||||
};
|
||||
#endif
|
||||
};
|
||||
|
||||
using pk_fp4_t = pk_float4_e2m1_t;
|
||||
using pk_fp4_raw_t = typename pk_fp4_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<pk_fp4_t>
|
||||
{
|
||||
using bitwise_type = pk_fp4_raw_t;
|
||||
|
||||
static constexpr int exp = 2;
|
||||
static constexpr int mant = 1;
|
||||
static constexpr int bias = 1;
|
||||
static constexpr int PackedSize = 2;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<pk_fp4_t>
|
||||
{
|
||||
static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
|
||||
static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
|
||||
static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
|
||||
static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
|
||||
{
|
||||
static_assert(I < 2, "Index is out of range.");
|
||||
if constexpr(I == 1)
|
||||
return (data >> 4);
|
||||
else
|
||||
return data & 0b00001111;
|
||||
}
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, pk_fp4_t)
|
||||
// TODO: consider replace this macro to improve performance
|
||||
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
|
||||
{
|
||||
// TODO: check the order
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16_t>)
|
||||
return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
|
||||
else if constexpr(std::is_same_v<T, bf16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
|
||||
else
|
||||
static_assert(std::false_type::value, "Unsupported type.");
|
||||
return T{};
|
||||
}
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
|
||||
{
|
||||
// TODO: check the order
|
||||
union
|
||||
{
|
||||
uint32_t u32;
|
||||
pk_fp4_raw_t pf4[4];
|
||||
} cvt{0};
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
|
||||
else
|
||||
static_assert(std::false_type::value, "Unsupported type.");
|
||||
return cvt.pf4[0];
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16_t>(data);
|
||||
#else
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16x2_t>(data);
|
||||
#else
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: make float_to_e2m1 generic so that we can convert from directrly.
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return convert_to_type<pk_fp4_t>(x);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
|
||||
#endif
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table == 0
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32_t>(data);
|
||||
#else
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32x2_t>(data);
|
||||
#else
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16_t>(data);
|
||||
#else
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16x2_t>(data);
|
||||
#else
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
{
|
||||
return e2m1_to_fp32_table[data & 0xf];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
{
|
||||
return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
{
|
||||
return e2m1_to_fp16_table[data & 0xf];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
{
|
||||
return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -64,6 +65,21 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
#endif
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -133,6 +134,28 @@ struct buffer_view<address_space_enum::generic,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the generic address space, we do not support the transpose instruction in the buffer view.
|
||||
Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -359,6 +382,28 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the global memory address space, we do not support the transpose instruction in the buffer
|
||||
view. Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -407,10 +452,12 @@ struct buffer_view<address_space_enum::global,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
|
||||
|
||||
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem,
|
||||
cached_buf_res_,
|
||||
src_wave_buffer_resource,
|
||||
i,
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
@@ -852,6 +899,47 @@ struct buffer_view<address_space_enum::lds,
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
|
||||
[[maybe_unused]] index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr address_space_enum addr_space = get_address_space();
|
||||
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
|
||||
p_data_ + i + linear_offset);
|
||||
#else
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
return X{invalid_element_value_};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -923,6 +1011,15 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// int8 on thread buffer
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
// ext_vector_type for pk_int4 must use int8_t as type
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
@@ -945,6 +1042,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
|
||||
{
|
||||
@@ -955,6 +1054,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
|
||||
{
|
||||
@@ -965,6 +1066,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
|
||||
{
|
||||
@@ -975,6 +1078,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
|
||||
{
|
||||
|
||||
@@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.async_load(
|
||||
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
|
||||
362
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
362
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
@@ -0,0 +1,362 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace util {
|
||||
template <typename Suffix, typename Sequence>
|
||||
struct is_sequence_suffix
|
||||
{
|
||||
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
|
||||
|
||||
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
|
||||
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
|
||||
|
||||
static constexpr bool value =
|
||||
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
|
||||
};
|
||||
|
||||
template <index_t... Xs>
|
||||
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Suffix, typename Sequence>
|
||||
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
|
||||
|
||||
} // namespace util
|
||||
|
||||
// Default policy: Retains original 2D transpose behavior
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
struct Quad16
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<4, 4>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
struct Quad8
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<2, 8>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
// Select based on data size
|
||||
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::InputEncoding,
|
||||
typename Quad8::InputEncoding>;
|
||||
|
||||
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::OutputEncoding,
|
||||
typename Quad8::OutputEncoding>;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
|
||||
// Programmable: Element grouping function
|
||||
static constexpr auto group_func = [](auto idx) {
|
||||
return idx; // Identity mapping
|
||||
};
|
||||
|
||||
template <typename InDstrEncode>
|
||||
struct ValidationTraits
|
||||
{
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
// 1. Must be 2D tensor
|
||||
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
|
||||
// 2. Quad pattern must be suffix of input pattern
|
||||
static constexpr bool suffix_valid_dim0 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
|
||||
decltype(input_hs_lengthss.template get<0>())>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
|
||||
decltype(input_hs_lengthss.template get<1>())>;
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
|
||||
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
|
||||
static constexpr index_t ndimp_inner =
|
||||
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
|
||||
|
||||
static constexpr bool ps_mapping_valid =
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
|
||||
input_hs_lengthss[number<1>{}].size() - 2) &&
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 1);
|
||||
|
||||
// 4. YS→RHS mapping constraints
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_to_rhs_major.back() == 2) &&
|
||||
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
|
||||
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
|
||||
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 2);
|
||||
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
};
|
||||
template <typename TileDistribution_, typename DataType_, typename Policy>
|
||||
struct TransposeTileDistrChecker
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
|
||||
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
|
||||
|
||||
static constexpr bool distr_encoding_valid = Validator::value;
|
||||
};
|
||||
|
||||
// this is used to generate the transposed output tile distribution encoding
|
||||
// based on the input tile distribution encoding
|
||||
template <typename TileDistribution_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
struct OutputTileDistributionTraits
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
|
||||
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
|
||||
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
|
||||
|
||||
// for transpose load
|
||||
// append the reversed quad output hs lengths to the input hs lengthss after removing
|
||||
// the quad_input_hs_lengthss
|
||||
// then reverse the whole sequence to get the dst_out_hs_lengthss
|
||||
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
|
||||
|
||||
static constexpr auto full_out_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
return input_hs_lengthss[i]
|
||||
.extract(typename arithmetic_sequence_gen<0,
|
||||
input_hs_lengthss[i].size() -
|
||||
quad_input_hs_lengthss[i].size(),
|
||||
1>::type{})
|
||||
.push_back(reversed_quad_output_hs_lengthss[i]);
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
|
||||
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
|
||||
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
|
||||
// sequence
|
||||
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_major.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_major[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_major.push_back(number<2>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_major[i];
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto minor_last_index =
|
||||
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
|
||||
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
|
||||
|
||||
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_minor[i];
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_minor.size()>{});
|
||||
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
|
||||
number<modified_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto modified_input_ys_to_rhs_major =
|
||||
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
|
||||
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
|
||||
number<modified_input_ys_to_rhs_major.size()>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_minor =
|
||||
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
|
||||
|
||||
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
};
|
||||
|
||||
template <typename InnerEncode,
|
||||
index_t kLeadIterPerWarp,
|
||||
index_t kSecondIterPerWarp,
|
||||
index_t kLeadNumWarps,
|
||||
index_t kSecondNumWarps>
|
||||
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
{
|
||||
constexpr auto block_outer_dst_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
|
||||
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto blk_distr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
|
||||
|
||||
return blk_distr_encode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
|
||||
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* This function is intended for use with statically distributed tensor tiles, where the input
|
||||
* and output tile distributions differ due to the transpose operation. It ensures that the
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
* @tparam NumCoord The number of coordinates (dimensions).
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
{
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<TileDistribution_,
|
||||
typename BottomTensorView_::DataType>::OutDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose<Policy>();
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
|
||||
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
|
||||
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
|
||||
|
||||
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
|
||||
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
|
||||
static_assert(y_in_element_space_size == y_out_element_space_size,
|
||||
"the element space size is not the same!");
|
||||
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
|
||||
"the vector length is not the same!");
|
||||
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
|
||||
constexpr index_t num_of_access =
|
||||
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
|
||||
|
||||
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
|
||||
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
|
||||
out_tensor.get_thread_buffer().template set_as<DataVec>(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
|
||||
: static_cast<index_t>(idx_y_start[ii]);
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
|
||||
@@ -161,7 +161,8 @@ struct tensor_view
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset) const
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
@@ -181,7 +182,8 @@ struct tensor_view
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
@@ -251,6 +253,33 @@ struct tensor_view
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element // flag
|
||||
) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
|
||||
}
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
|
||||
@@ -542,26 +542,26 @@ namespace detail {
|
||||
//
|
||||
// e.g
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
|
||||
// totally 16 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
|
||||
// |--> slice along this P dim, will split threads, not supported
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, but this Y sim need to split into 2
|
||||
@@ -577,11 +577,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
|
||||
|
||||
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
|
||||
static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
|
||||
constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
|
||||
|
||||
constexpr auto x_slice_ends_ = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(x_slice_ends[i] == -1)
|
||||
{
|
||||
// -1 means till the end
|
||||
constexpr auto x_length_ =
|
||||
container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
|
||||
return x_length_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return x_slice_ends[i];
|
||||
}
|
||||
},
|
||||
number<x_slice_ends.size()>{});
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
|
||||
|
||||
constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
constexpr auto len_ = x_slice_lengths[i];
|
||||
static_assert(len_ % p_len_over_h[i] == 0,
|
||||
"slice length must be dividable by p_len_over_h");
|
||||
return number<len_ / p_len_over_h[i]>{};
|
||||
},
|
||||
number<x_slice_lengths.size()>{});
|
||||
|
||||
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
|
||||
constexpr auto src_y_dims = src_y_info[number<0>{}];
|
||||
constexpr auto src_y_maps = src_y_info[number<1>{}];
|
||||
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
|
||||
@@ -590,14 +618,15 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
{
|
||||
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
|
||||
auto y_slice_lengths = Encoding::detail::ys_lengths_;
|
||||
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
|
||||
|
||||
// This lambda will modify some value outside, so c++ will not treat return value as
|
||||
// constexpr
|
||||
// TODO: ugly
|
||||
auto new_h_lengths = transform_tuples(
|
||||
[&](auto h_len, auto id) {
|
||||
constexpr auto sliced_h =
|
||||
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
|
||||
constexpr auto sliced_h = reverse_slice_sequence(
|
||||
h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
|
||||
|
||||
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
|
||||
constexpr auto sliced_h_index = sliced_h[number<2>{}];
|
||||
@@ -605,26 +634,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
// update y_slice_lengths
|
||||
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
|
||||
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
|
||||
constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
|
||||
|
||||
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
|
||||
"not sliced at y dim, please check");
|
||||
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[found_y_index - i]) =
|
||||
sliced_h_lens[sliced_h_index - i];
|
||||
});
|
||||
{
|
||||
constexpr auto sliced_y_to_h_lens =
|
||||
pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
|
||||
constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
|
||||
static_for<0, sliced_y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
|
||||
sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
|
||||
});
|
||||
}
|
||||
// TODO: add validations not across p dim
|
||||
|
||||
// NOTE: this y_origin is for all dims, not only current dim
|
||||
// will later use pick to select target dim
|
||||
constexpr auto y_origin = [&]() {
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
|
||||
// can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
|
||||
constexpr auto y_to_h_len =
|
||||
pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
|
||||
constexpr auto y_to_h_dims = y_to_h_len.size();
|
||||
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
|
||||
|
||||
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
|
||||
|
||||
static_for<0, y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
|
||||
});
|
||||
return y_origin_;
|
||||
}();
|
||||
|
||||
@@ -47,6 +47,11 @@ struct tile_distribution_encoding
|
||||
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
|
||||
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
|
||||
|
||||
#if !CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
|
||||
"do not support Y dim pointed to R dim");
|
||||
#endif
|
||||
|
||||
// redundant but useful info
|
||||
// TODO: really bad code, should be over-hauled
|
||||
struct detail
|
||||
@@ -255,33 +260,107 @@ struct tile_distribution_encoding
|
||||
}
|
||||
}();
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_h_dim_lengths()
|
||||
{
|
||||
// <len_d0, len_d1, ...>
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
|
||||
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
constexpr index_t size = HsLengthss{}[i].size();
|
||||
return number<size>{};
|
||||
constexpr index_t size_ = HsLengthss{}[i].size();
|
||||
return number<size_>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
return uniformed_h_dim_lengths;
|
||||
}
|
||||
|
||||
// note: this function only count the p dim length along h, not r
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_p_dim_lengths_over_h()
|
||||
{
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
|
||||
// Y P Y Y P Y P Y
|
||||
// | | |
|
||||
// v v v
|
||||
// return : seq<4, 2 * 4> => seq<4, 8>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto p_len_ = [&]() {
|
||||
array<index_t, NDimX> len_{1};
|
||||
static_for<0, NDimX, 1>{}([&](auto idim_x_) {
|
||||
constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
|
||||
static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
|
||||
if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
|
||||
{
|
||||
constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
|
||||
constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
|
||||
len_[idim_x_] *= h_length_;
|
||||
}
|
||||
});
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
|
||||
return p_len_over_h_seq_;
|
||||
}
|
||||
|
||||
//
|
||||
// R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
|
||||
// => return seq<1, 3, 5>
|
||||
// R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
|
||||
// => return seq<0, 2, 3>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_rh_dim_lengths()
|
||||
{
|
||||
constexpr auto uniformed_rh_dim_lengths =
|
||||
merge_sequences(sequence<NDimR>{} /*for R dims*/, get_uniformed_h_dim_lengths());
|
||||
|
||||
return uniformed_rh_dim_lengths;
|
||||
}
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
|
||||
|
||||
return h_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_rh_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
|
||||
|
||||
return rh_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_p_to_h()
|
||||
{
|
||||
// tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto all_ps_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
uniformed_ps_to_rhss_major_,
|
||||
uniformed_ps_to_rhss_minor_);
|
||||
|
||||
return all_ps_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_rh()
|
||||
{
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
// <0, 0, len_d0, len_d0+len_d1, ...>
|
||||
constexpr auto x_dim_prefix_sum = merge_sequences(
|
||||
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
|
||||
return x_dim_prefix_sum.at(major) + minor;
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
@@ -289,6 +368,45 @@ struct tile_distribution_encoding
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
{
|
||||
// TODO: Y can't point to R
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor - NDimR;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
// return tuple of seq
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
|
||||
{
|
||||
constexpr auto masks_ = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto size_ = HsLengthss{}[i].size();
|
||||
constexpr auto current_y_to_h_mask_ = [&]() {
|
||||
array<index_t, size_> m_{0};
|
||||
// TODO: we loop over all y for each h dim
|
||||
for(auto j = 0; j < NDimY; j++)
|
||||
{
|
||||
if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
|
||||
{
|
||||
m_[Ys2RHsMinor{}[j]] = 1;
|
||||
}
|
||||
}
|
||||
return m_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(current_y_to_h_mask_, size_);
|
||||
},
|
||||
number<NDimX>{});
|
||||
return masks_;
|
||||
}
|
||||
|
||||
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
|
||||
template <typename IdxSeq, typename PrefixSumSeq>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
|
||||
@@ -305,7 +423,8 @@ struct tile_distribution_encoding
|
||||
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
|
||||
// Note here y_to_h does not count R dim!
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_to_h_info()
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}
|
||||
|
||||
@@ -344,37 +344,82 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
|
||||
// check?)
|
||||
constexpr index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
constexpr index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
constexpr index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
// Precompute invariant values outside loops
|
||||
const auto window_origin = lds_tile.get_window_origin();
|
||||
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
|
||||
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// Use precomputed window origin
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
window_origin + window_adaptor_thread_coord.get_bottom_index();
|
||||
|
||||
// Use precomputed tensor descriptor
|
||||
const auto lds_coord =
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
// Calculate SMEM address using base pointer
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
|
||||
|
||||
// Write into bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
number<0>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// Move thread coordinate if not last access
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose() const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
this->template load_transpose<Policy>(
|
||||
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename Policy,
|
||||
typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
constexpr auto group_func = Policy::group_func;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -385,10 +430,31 @@ struct tile_window_with_static_distribution
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// read from bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view()
|
||||
.template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0);
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto orig_idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
|
||||
|
||||
constexpr index_t linear_distributed_index =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
|
||||
vec_value.template get_as<typename Base::DataType>()[j];
|
||||
});
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
@@ -400,8 +466,6 @@ struct tile_window_with_static_distribution
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
smem += size_per_issue; // Note we manually increase the per-issue offset
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -415,7 +479,6 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ struct tile_window_linear
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
: cached_coords_{}, cached_flags_{}
|
||||
: cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{}
|
||||
{
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->window_lengths_ = window_lengths;
|
||||
@@ -214,7 +214,8 @@ struct tile_window_linear
|
||||
|
||||
if constexpr(need_save_non_linear_coord)
|
||||
{
|
||||
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
|
||||
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
|
||||
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
|
||||
}
|
||||
|
||||
// TODO: need pad_tensor_view to check which dim need use flag to check
|
||||
@@ -314,8 +315,7 @@ struct tile_window_linear
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto dst_tensor =
|
||||
make_static_distributed_tensor<typename Base::DataTypeDataType>(tile_dstr);
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
@@ -348,8 +348,9 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
|
||||
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
};
|
||||
|
||||
@@ -400,8 +401,9 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
|
||||
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
};
|
||||
|
||||
@@ -553,66 +555,101 @@ struct tile_window_linear
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
using vector_t = typename traits::vector_t;
|
||||
|
||||
// currently we only support everything is non linear dim
|
||||
// actually it's not performant if we have linear dim(e.g. fast changing)
|
||||
static_assert(NumAccess_NonLinear == NumAccess);
|
||||
static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
|
||||
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global);
|
||||
address_space_enum::global,
|
||||
"Requires global memory");
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
// Precompute invariant values outside the lambda
|
||||
const auto window_origin = lds_tile.get_window_origin();
|
||||
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
|
||||
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
|
||||
// check?)
|
||||
constexpr index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
constexpr index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
constexpr index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
|
||||
// Use precomputed values
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
// read from bottom tensor
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
window_origin + window_adaptor_coord.get_bottom_index();
|
||||
const auto lds_coord =
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
|
||||
|
||||
// Read from bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(i_access_ != (NumAccess - 1))
|
||||
{
|
||||
smem += size_per_issue; // Note we manually increase the per-issue offset
|
||||
}
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose() const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
this->template load_transpose_linear<Policy>(
|
||||
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename Policy,
|
||||
typename DistributedTensor,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
constexpr auto group_func = Policy::group_func;
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view().template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t linear_distributed_index =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
|
||||
vec_value.template get_as<typename Base::DataType>()[j];
|
||||
});
|
||||
};
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
|
||||
typename Base::TileDstr>& dstr_tensor,
|
||||
@@ -750,8 +787,7 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<typename Base::DataTypeDataType>()(
|
||||
j / Base::Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -806,8 +842,7 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<typename Base::DataTypeDataType>()(
|
||||
j / Base::Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -875,7 +910,8 @@ struct tile_window_linear
|
||||
|
||||
if constexpr(need_save_non_linear_coord)
|
||||
{
|
||||
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
|
||||
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
|
||||
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
|
||||
}
|
||||
|
||||
if constexpr(i_access != (NumAccess - 1))
|
||||
@@ -895,6 +931,8 @@ struct tile_window_linear
|
||||
|
||||
// this contains:
|
||||
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
|
||||
array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear>
|
||||
cached_window_adaptor_coords_;
|
||||
array<bool, Base::Traits::NumAccess> cached_flags_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user