Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline

This commit is contained in:
aska-0096
2025-07-17 07:24:32 +00:00
430 changed files with 41159 additions and 6951 deletions

View File

@@ -1,6 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
/**
* @file
* We're defining the data access pattern for a 2D window (`XPerTile` by `YPerTile`)
for `BlockSize` threads in a thread block.
* X dimension is considered contiguous in memory, so a single instruction can access
several adjacent and properly aligned elements (vector); the access pattern along X tile
dimension is parameterized only by the suggested vector size `VecSize`.
* We can't access more than `MaxVecSize = TileElementsPerThread = TileSize / BlockSize` elements
with a single memory access, so the actual vector size along the X dimension is
`X0 = min(MaxVecSize, VecSize)`.
* This leaves `X1 = XPerTile / X0` threads per tile in X dimension.
* X1 is also the number of threads per warp in the X dimension; that is,
the X dimension is not split between warps, each warp accesses it entirely,
and there is no iteration along X.
* The tuple <X0, X1> defines the X-axis access pattern.
This part is common between the 2D distribution patterns.
* What differs between the 2D distribution patterns is the Y-axis access pattern.
* There are 3 components in this access pattern:
* (1) number of Y-axis elements (rows) per warp for a single instruction access,
* (2) number of warps per thread block,
* (3) number of iterations to cover the entire Y axis.
* 'Raked' here describes how the data is partitioned across the different processing
granularities, i.e. whether a thread, a warp, or the whole block accesses a contiguous
region.
* The qualifier for 'raked' names the level of the warp/thread hierarchy,
* in the split of the Y tile dimension, at which the iteration happens;
* that is, the iteration can be logically inserted as a tile dimension in 3 ways:
* (1) after thread -> thread-raked,
* (2) between warp and thread -> warp-raked,
* (3) before warp -> block-raked
* *Thread raked*
* Y0 is the number of warps, which we can get from the equation `Y0 * WarpSize == BlockSize`
* Y1 is the number of rows accessed by a warp within a single iteration,
compute it from the equation `Y1 * X1 == WarpSize`
* Y2 is the number of iterations to cover the tile,
compute it from the equation `Y0 * Y1 * Y2 == YPerTile`
* *Warp raked*
* Y0 is the number of warps, we can get it in the same way as for thread-raked pattern,
`Y0 * WarpSize == BlockSize`
* Y1 is the number of iterations to cover the tile; compute it from
`Y0 * Y1 * Y2 == YPerTile` once Y2 is known from the equation below
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
* *Block raked*
* Y0 is the number of iterations to cover the tile; compute it from
`Y0 * Y1 * Y2 == YPerTile` once Y1 and Y2 are known from the equations below
* Y1 is the number of warps, `Y1 * WarpSize == BlockSize`
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
* In all cases, the tuple <Y0, Y1, Y2> defines the Y-axis access pattern.
* *Selection*
* Thread-raked is used for element-wise operations because it yields a
* thread-major memory order.
* Warp-raked is used for matrix multiplication because the vectorization happens at warp level.
* Block-raked is used mostly for reductions, where the block's result is combined via
* global atomics.
*
*/
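// A minimal worked example of the equations above (illustrative numbers only,
// not taken from any kernel config): take BlockSize = 256, WarpSize = 64,
// XPerTile = 64, YPerTile = 32, VecSize = 8.
//   MaxVecSize = TileSize / BlockSize = (64 * 32) / 256 = 8, so X0 = min(8, 8) = 8
//   and X1 = XPerTile / X0 = 64 / 8 = 8.
//   Thread-raked: Y0 = 256 / 64 = 4 warps, Y1 = 64 / 8 = 8 rows per access,
//                 Y2 = 32 / (4 * 8) = 1 iteration.
//   Warp-raked:   Y0 = 4 warps, Y2 = 64 / 8 = 8 rows per access,
//                 Y1 = 32 / (4 * 8) = 1 iteration.
//   Block-raked:  Y1 = 4 warps, Y2 = 64 / 8 = 8 rows per access,
//                 Y0 = 32 / (4 * 8) = 1 iteration.
// In every case X0 * X1 == XPerTile, Y0 * Y1 * Y2 == YPerTile, and the thread count
// (warps * rows-per-access * X1) equals BlockSize.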
#pragma once
#include "ck_tile/core/arch/arch.hpp"
@@ -105,9 +172,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<Y0>,
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0>, sequence<1, 2>>,
tuple<sequence<0>, sequence<0, 0>>,
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<1, 1>>{});
sequence<1, 1>>{}); // -> <Y2, X1>
}
else
{
@@ -115,9 +182,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<2, 1>>{});
sequence<2, 1>>{}); // -> <Y2, X1>
}
}
@@ -129,9 +196,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<Y0>,
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<1, 1>>{});
sequence<1, 1>>{}); // -> <X1, Y2>
}
else
{
@@ -139,9 +206,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<1, 2>>{});
sequence<1, 2>>{}); // -> <X1, Y2>
}
}
};
@@ -182,9 +249,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
sequence<1, 2>,
sequence<1, 1>>{});
sequence<1, 1>>{}); // -> <Y1, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
@@ -193,9 +260,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<2, 0>>,
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
sequence<1, 2>,
sequence<1, 1>>{});
sequence<1, 1>>{}); // -> <X1, Y1>
}
};
@@ -233,9 +300,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
sequence<1, 2>,
sequence<0, 1>>{});
sequence<0, 1>>{}); // -> <Y0, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
@@ -244,9 +311,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
sequence<1, 2>,
sequence<1, 0>>{});
sequence<1, 0>>{}); // -> <X1, Y0>
}
};

View File

@@ -13,6 +13,7 @@
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
// This attribute gives a hint to the compiler that a branch is likely to be taken.
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
@@ -23,6 +24,8 @@
#define LIKELY(x) (__builtin_expect(!!(x), 1))
#endif
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
// 128 bit SGPRs to supply buffer resource in buffer instructions
@@ -1270,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
@@ -1760,7 +1763,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
@@ -1790,29 +1793,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t bytes = sizeof(T) * N;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
ignore = src_wave_addr_offset;
ignore = src_immediate_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
0,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
0,
static_cast<index_t>(coherence));
}
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
#endif
}
template <index_t N,
@@ -2786,15 +2821,52 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
}
#endif
} // namespace ck_tile
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN

View File

@@ -13,6 +13,9 @@
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
@@ -1138,7 +1141,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
@@ -1560,29 +1563,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t bytes = sizeof(T) * N;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
ignore = src_wave_addr_offset;
ignore = src_immediate_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
0,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
0,
static_cast<index_t>(coherence));
}
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
#endif
}
template <index_t N,
@@ -2556,15 +2591,52 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
}
#endif
} // namespace ck_tile
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
// this generates the wave-level tile distribution
template <typename T, typename = void>
struct LaneGroupTransposeTraits;
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
{
// before transpose, 4x16
static constexpr index_t ksecondDim = 4;
static constexpr index_t kleadDim = 16;
// after transpose, 16x4
static constexpr index_t ksecondDimT = 16;
static constexpr index_t kleadDimT = 4;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
};
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
{
static constexpr index_t ksecondDim = 8;
static constexpr index_t kleadDim = 16;
static constexpr index_t ksecondDimT = 16;
static constexpr index_t kleadDimT = 8;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
};
/*
* @brief This function is used to generate the transposed distribution encoding
* for the given data type and distribution dimensions.
*
* @tparam T The data type of the elements in the tensor.
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
* consecutive.
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
* consecutive.
*/
template <typename T,
index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
{
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
return xdllevel_dstr_encoding{};
}
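// A hedged usage sketch (the element type and dimension values below are
// illustrative, not taken from an existing kernel):
//   constexpr auto enc = make_transposed_distr_encode<ck_tile::half_t,
//                                                     /*kOuterDistDim0=*/2,
//                                                     /*kOuterDistDim1=*/1,
//                                                     /*kInnerDistDim0=*/2,
//                                                     /*kInnerDistDim1=*/1>();
// For a 16-bit type this selects LaneGroupTransposeTraits<T> with the 4x16 ->
// 16x4 lane-group transpose defined above.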
} // namespace ck_tile

View File

@@ -10,6 +10,15 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
#define CK_TILE_VMCNT(cnt) \
([]() { static_assert((cnt) < 0b111111, "VMCNT only has 6 bits"); }(), \
((cnt)&0b1111) | (((cnt)&0b110000) << 10))
#define CK_TILE_EXPCNT(cnt) \
([]() { static_assert((cnt) < 0b111, "EXP only has 3 bits"); }(), ((cnt) << 4))
#define CK_TILE_LGKMCNT(cnt) \
([]() { static_assert((cnt) < 0b1111, "LGKM only has 4 bits"); }(), ((cnt) << 8))
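// A sketch of the bit packing these macros assume (the gfx9-style s_waitcnt
// immediate layout: vmcnt[3:0] in bits 3:0, expcnt in bits 6:4, lgkmcnt in
// bits 11:8 and vmcnt[5:4] in bits 15:14, which is what the shifts above encode).
// For an illustrative count of 18 (0b010010):
//   CK_TILE_VMCNT(18) == (18 & 0b1111) | ((18 & 0b110000) << 10) == 0x4002
// i.e. the low nibble of the count lands in bits 3:0 and the top two bits in
// bits 15:14 of the immediate.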
namespace ck_tile {
template <typename, bool>
@@ -113,13 +122,12 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
#endif
}
template <index_t vmcnt>
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
// We don't sync the lds insts here.
__builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_VMCNT(vmcnt));
__builtin_amdgcn_s_barrier();
}
CK_TILE_DEVICE void s_nop(index_t cnt = 0)

View File

@@ -263,3 +263,9 @@
#ifndef CK_TILE_WA_ISSUE_2028
#define CK_TILE_WA_ISSUE_2028 0
#endif
// Y pointing to R: we don't see a valuable use case for it.
// If set to zero, the encoding check enforces that Y does not point to R.
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
#define CK_TILE_ENC_SUPPORT_Y_TO_R 0
#endif

View File

@@ -1178,6 +1178,15 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
// the length count for the slice size goes from right to left (reverse slice)
// or, in other words, find the greatest common divisor (gcd) from right to left for the slice length
//
// e.g. <2, 8, 4>, slice length = 16
// step-1: we take the rightmost <*, *, 4>, remaining 16/4=4
// step-2: we only need 4 out of 8 of the middle dim, hence <*, 4, 4>
// step-3: since nothing remains, the first dim only needs 1, hence <1, 4, 4>
// => we got <1, 4, 4> as the length of each slice
// => total number of slice = <2, 8, 4> / <1, 4, 4> = <2, 2, 1>
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
@@ -1197,7 +1206,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// or the first index (right -> left) at which the sliced length differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
@@ -1207,6 +1216,11 @@ constexpr auto reverse_slice_sequence(Seq,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
static_assert(SliceSize != 0, "slice size zero is invalid");
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
SliceSize ==
0,
"slice size can't evenly divide input sizes");
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,

View File

@@ -7,6 +7,7 @@
namespace ck_tile {
using index_t = int32_t;
using int32_t = int32_t;
using long_index_t = int64_t;
using int8_t = int8_t;

View File

@@ -0,0 +1,213 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// modified from include/ck/utility/mxfp_utils.hpp
template <typename T>
struct numeric_utils : numeric_traits<T>
{
using traits = numeric_traits<T>;
using _numeric = numeric<T>;
using raw_type = typename T::raw_type;
static constexpr int exp_mask = (1 << traits::exp) - 1;
static constexpr int get_exponent(raw_type x)
{
// TODO: check if repeated calls are optimized.
return (x >> traits::mant) & exp_mask;
}
static constexpr bool is_positive(raw_type x)
{
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
}
static constexpr bool is_subnormal(raw_type x)
{
return get_exponent(x) == _numeric::binary_zero;
}
// TODO: replace double with template arg?
static constexpr double get_mantissa(raw_type x)
{
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
for(uint32_t i = 0; i < traits::mant; ++i)
{
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
x >>= 1;
}
return mantissa;
}
};
template <typename T>
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
{
using utils = numeric_utils<T>;
static constexpr int e8m0_bias = 127; // TODO: make it generic.
float sign = utils::is_positive(data) ? 1.0 : -1.0;
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
float mant = utils::get_mantissa(data);
return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
}
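// A hedged worked example of convert_to_float (values are illustrative and use
// the e2m1 layout, i.e. exp = 2, mant = 1, bias = 1): for the raw nibble 0b0100
// the exponent field is 0b10 = 2 (not subnormal), so exp = 2 - bias = 1, the
// mantissa evaluates to 1.0, and with the default scale_exp = 127 the result is
// ldexp(1.0, 1 + 127 - 127) = 2.0.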
template <typename T>
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
{
using bitwise_type = typename numeric_traits<T>::bitwise_type;
if(std::abs(value) > float(numeric<T>::max()))
{
float max_value = numeric<T>::max();
// cppcheck-suppress redundantAssignment
uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
// cppcheck-suppress redundantAssignment
bitwise_type sign =
bit_cast<uint32_t>(value) >> (numeric_traits<float>::exp + numeric_traits<float>::mant);
bitwise_type exp =
((max_bitwise >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask) -
(numeric_traits<float>::bias - numeric_traits<T>::bias);
bitwise_type mantissa =
max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
mant_prev--;
mant_prev <<= (numeric_traits<float>::mant - numeric_traits<T>::mant);
uint32_t prev_bit =
((max_bitwise >> numeric_traits<float>::mant) << numeric_traits<float>::mant) |
mant_prev;
float prev_val = bit_cast<float>(prev_bit);
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
(exp << numeric_traits<T>::mant) | mantissa;
}
else
{
if constexpr(!numeric<T>::has_inf())
{
return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
}
else
{
exp++;
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
(exp << numeric_traits<T>::mant);
}
}
}
const int mfmt = numeric_traits<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & numeric_traits<float>::head_mask;
mantissa = x & numeric_traits<float>::mant_mask;
exponent = (head >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask;
sign = head >> (numeric_traits<float>::mant + numeric_traits<float>::exp);
bias = numeric_traits<float>::bias;
if(x == 0)
{
return 0b0;
}
const int mini_bias = numeric_traits<T>::bias;
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
mantissa += (1UL << mfmt);
}
auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
else
return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
mantissa >>= (mfmt - numeric_traits<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
}
mantissa &= (1UL << numeric_traits<T>::mant) - 1;
return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
(out_exponent << numeric_traits<T>::mant) | mantissa;
}
} // namespace ck_tile

View File

@@ -103,94 +103,92 @@ struct numeric_traits<float>
} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return std::abs(static_cast<float>(x) - static_cast<float>(y)) < \
static_cast<float>(numeric<type_>::epsilon()); \
} \
attr_ bool operator!=(const type_& x, const type_& y) { return not operator==(x, y); } \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}

View File

@@ -0,0 +1,324 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cmath>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
#if defined(__gfx950__)
#define CK_TILE_FP4_CVT_DEVICE 1
#else
#define CK_TILE_FP4_CVT_DEVICE 0
#endif
#define TEST_convert_with_table 0
namespace ck_tile {
using fp32_t = float;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
// TODO: Add stochastic method
struct pk_float4_e2m1_t
{
static constexpr int exponent = 2;
static constexpr int mantissa = 1;
static constexpr int bias = 1;
// TODO: Can we merge raw_type and type?
using raw_type = uint8_t;
using type = raw_type;
raw_type data;
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {}
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
{
}
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
{
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
CK_TILE_HOST_DEVICE constexpr operator float() const;
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
template <index_t I>
CK_TILE_HOST_DEVICE raw_type unpack(number<I>) const;
CK_TILE_HOST_DEVICE static pk_float4_e2m1_t pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
#if TEST_convert_with_table
static constexpr float e2m1_to_fp32_table[16] = {
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
static constexpr fp16_t e2m1_to_fp16_table[16] = {
bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
};
#endif
};
using pk_fp4_t = pk_float4_e2m1_t;
using pk_fp4_raw_t = typename pk_fp4_t::raw_type;
template <>
struct numeric_traits<pk_fp4_t>
{
using bitwise_type = pk_fp4_raw_t;
static constexpr int exp = 2;
static constexpr int mant = 1;
static constexpr int bias = 1;
static constexpr int PackedSize = 2;
};
// limits
template <class T>
struct numeric;
template <>
struct numeric<pk_fp4_t>
{
static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
// N/A
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
// N/A
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
// N/A
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
};
template <index_t I>
CK_TILE_HOST_DEVICE pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 1)
return (data >> 4);
else
return data & 0b00001111;
}
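// A minimal usage sketch of pack/unpack (the nibble values are illustrative e2m1
// encodings: 0b0010 is 1.0 and 0b0100 is 2.0):
//   pk_fp4_t v = pk_fp4_t::pack(/*x0=*/0b0010, /*x1=*/0b0100); // x0 goes in the low nibble
//   auto lo = v.unpack(number<0>{}); // 0b0010 -> 1.0
//   auto hi = v.unpack(number<1>{}); // 0b0100 -> 2.0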
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, pk_fp4_t)
// TODO: consider replace this macro to improve performance
#if CK_TILE_FP4_CVT_DEVICE
namespace impl {
template <typename T>
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
{
// TODO: check the order
if constexpr(std::is_same_v<T, fp32_t>)
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, fp32x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
else if constexpr(std::is_same_v<T, fp16_t>)
return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, fp16x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
else if constexpr(std::is_same_v<T, bf16_t>)
return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, bf16x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
else
static_assert(std::false_type::value, "Unsupported type.");
return T{};
}
template <typename T>
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
{
// TODO: check the order
union
{
uint32_t u32;
pk_fp4_raw_t pf4[4];
} cvt{0};
if constexpr(std::is_same_v<T, fp32_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
else if constexpr(std::is_same_v<T, fp32x2_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
else if constexpr(std::is_same_v<T, fp16_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
else if constexpr(std::is_same_v<T, fp16x2_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
else if constexpr(std::is_same_v<T, bf16_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
else if constexpr(std::is_same_v<T, bf16x2_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
else
static_assert(std::false_type::value, "Unsupported type.");
return cvt.pf4[0];
}
} // namespace impl
#endif
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16_t>(data);
#else
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16x2_t>(data);
#else
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
#endif
}
// TODO: make float_to_e2m1 generic so that we can convert directly.
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
#else
return convert_to_type<pk_fp4_t>(x);
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); }
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); }
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); }
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
#else
return float_to_e2m1(type_convert<float>(x));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
#else
return float_to_e2m1(type_convert<float>(x));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
float_to_e2m1(type_convert<float>(x[1])));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
float_to_e2m1(type_convert<float>(x[1])));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
#else
return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
#endif
}
#if TEST_convert_with_table == 0
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32_t>(data);
#else
return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32x2_t>(data);
#else
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16_t>(data);
#else
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16x2_t>(data);
#else
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
#endif
}
#else
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
{
return e2m1_to_fp32_table[data & 0xf];
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
{
return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]};
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
{
return e2m1_to_fp16_table[data & 0xf];
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
{
return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
}
#endif
} // namespace ck_tile

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
namespace ck_tile {
@@ -64,6 +65,21 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
} // namespace ck_tile
#include "ck_tile/core/numeric/pk_fp4.hpp"
namespace ck_tile {
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
#undef CK_TILE_TYPE_CONVERT
#endif

View File

@@ -18,6 +18,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/ignore.hpp"
namespace ck_tile {
@@ -133,6 +134,28 @@ struct buffer_view<address_space_enum::generic,
}
}
/*
In the generic address space, we do not support the transpose instruction in the buffer view.
A compilation error is reported when a developer tries to use it.
*/
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
static_assert(false, "Error: transpose load not supported in global memory space.");
ignore = i;
ignore = linear_offset;
ignore = is_valid_element;
return;
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
@@ -359,6 +382,28 @@ struct buffer_view<address_space_enum::global,
}
}
/*
In the global memory address space, we do not support the transpose instruction in the buffer
view. A compilation error is reported when a developer tries to use it.
*/
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
static_assert(false, "Error: transpose load not supported in global memory space.");
ignore = i;
ignore = linear_offset;
ignore = is_valid_element;
return;
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
@@ -407,10 +452,12 @@ struct buffer_view<address_space_enum::global,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem,
cached_buf_res_,
src_wave_buffer_resource,
i,
linear_offset,
is_valid_element,
@@ -852,6 +899,47 @@ struct buffer_view<address_space_enum::lds,
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
}
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
[[maybe_unused]] index_t linear_offset,
bool is_valid_element) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
if(is_valid_element)
{
#if defined(__gfx950__)
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr address_space_enum addr_space = get_address_space();
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
p_data_ + i + linear_offset);
#else
return X{numeric<remove_cvref_t<T>>::zero()};
#endif
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
return X{invalid_element_value_};
}
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
@@ -923,6 +1011,15 @@ struct buffer_view<address_space_enum::lds,
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// int8 on thread buffer
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
@@ -945,6 +1042,8 @@ struct buffer_view<address_space_enum::lds,
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
{
@@ -955,6 +1054,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
{
@@ -965,6 +1066,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
{
@@ -975,6 +1078,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
{

View File

@@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,

View File

@@ -0,0 +1,362 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
namespace util {
template <typename Suffix, typename Sequence>
struct is_sequence_suffix
{
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
static constexpr bool value =
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
};
template <index_t... Xs>
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
{
static constexpr bool value = true;
};
template <typename Suffix, typename Sequence>
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
} // namespace util
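// A small compile-time sketch of the suffix check above (the sequences are
// illustrative only):
//   static_assert(util::is_sequence_suffix_v<sequence<4, 4>, sequence<2, 4, 4>>, "");
//   static_assert(!util::is_sequence_suffix_v<sequence<2, 4>, sequence<2, 4, 4>>, "");
// i.e. the suffix must match the trailing elements of the longer sequence.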
// Default policy: Retains original 2D transpose behavior
template <typename DataType>
struct DefaultTranspose
{
struct Quad16
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<4, 4>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
struct Quad8
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<2, 8>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
// Select based on data size
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::InputEncoding,
typename Quad8::InputEncoding>;
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::OutputEncoding,
typename Quad8::OutputEncoding>;
// Always swap last two dimensions
static constexpr auto transpose_dims = sequence<1, 0>{};
// Programmable: Element grouping function
static constexpr auto group_func = [](auto idx) {
return idx; // Identity mapping
};
template <typename InDstrEncode>
struct ValidationTraits
{
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
// 1. Must be 2D tensor
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
// 2. Quad pattern must be suffix of input pattern
static constexpr bool suffix_valid_dim0 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
decltype(input_hs_lengthss.template get<0>())>;
static constexpr bool suffix_valid_dim1 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
decltype(input_hs_lengthss.template get<1>())>;
// 3. PS→RHS mapping constraints
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
static constexpr index_t ndimp_inner =
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
static constexpr bool ps_mapping_valid =
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
input_hs_lengthss[number<1>{}].size() - 2) &&
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
input_hs_lengthss[number<0>{}].size() - 1);
// 4. YS→RHS mapping constraints
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr bool ys_mapping_valid =
(input_ys_to_rhs_major.back() == 2) &&
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
input_hs_lengthss[number<0>{}].size() - 2);
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
ps_mapping_valid && ys_mapping_valid;
};
};
template <typename TileDistribution_, typename DataType_, typename Policy>
struct TransposeTileDistrChecker
{
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
static constexpr bool distr_encoding_valid = Validator::value;
};
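// Illustrative compile-time use (hypothetical names my_dstr / MyDataType):
//   static_assert(TransposeTileDistrChecker<decltype(my_dstr), MyDataType,
//                                           DefaultTranspose<MyDataType>>::distr_encoding_valid,
//                 "tile distribution is not transposable");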
// Generates the transposed output tile distribution encoding
// from the input tile distribution encoding
template <typename TileDistribution_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
struct OutputTileDistributionTraits
{
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
// For the transpose load:
// append the reversed quad output hs lengths to the input hs lengthss after removing
// the quad_input_hs_lengthss suffix,
// then reverse the whole tuple to get dst_out_hs_lengthss
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
static constexpr auto full_out_hs_lengthss = generate_tuple(
[](auto i) {
return input_hs_lengthss[i]
.extract(typename arithmetic_sequence_gen<0,
input_hs_lengthss[i].size() -
quad_input_hs_lengthss[i].size(),
1>::type{})
.push_back(reversed_quad_output_hs_lengthss[i]);
},
number<InDstrEncode::NDimX>{});
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
// For the PS→RHS mapping (both major and minor), we need to modify the last sequence:
// trim the quad part and append the new last element
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
[](auto i) {
if constexpr(i == input_ps_to_rhss_major.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_major[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_major.push_back(number<2>{});
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_major[i];
}
},
number<input_ps_to_rhss_major.size()>{});
static constexpr auto minor_last_index =
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
[](auto i) {
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_minor[i];
}
},
number<input_ps_to_rhss_minor.size()>{});
// For the major indices: because dst_out_hs_lengthss is reversed, these indices also need to be
// swapped (1 <-> 2)
static constexpr auto swap_one_and_two = [](const index_t idx) {
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
};
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
number<modified_ps_to_rhss_major.size()>{});
static constexpr auto modified_input_ys_to_rhs_major =
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
number<modified_input_ys_to_rhs_major.size()>{});
static constexpr auto dst_ys_to_rhs_minor =
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
};
template <typename InnerEncode,
index_t kLeadIterPerWarp,
index_t kSecondIterPerWarp,
index_t kLeadNumWarps,
index_t kSecondNumWarps>
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
{
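// Illustrative reading (assumption): the outer encoding places the warp counts
// (kSecondNumWarps, kLeadNumWarps) on the P mapping and the per-warp iteration counts
// (kSecondIterPerWarp, kLeadIterPerWarp) on the Y mapping; embedding InnerEncode below then
// refines each per-warp slot with the intra-warp access pattern.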
constexpr auto block_outer_dst_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto blk_distr_encode =
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
return blk_distr_encode;
}
/**
* @brief Transpose-loads a tile from a tensor and returns the resulting tensor with a new
* (transposed) tile distribution. SFINAE is used to ensure the tile distribution encoding
* is valid.
*
* This function is intended for use with statically distributed tensor tiles, where the input
* and output tile distributions differ due to the transpose operation. It ensures that the
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
* @tparam NumCoord The number of coordinates (dimensions).
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* The last, unnamed template parameter applies SFINAE to ensure the tile distribution
* encoding is valid.
*
* @param tile_window The tile window with static distribution to load and transpose.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
{
using OutTileDstrEncode =
typename OutputTileDistributionTraits<TileDistribution_,
typename BottomTensorView_::DataType>::OutDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose<Policy>();
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
static_assert(y_in_element_space_size == y_out_element_space_size,
"the element space size is not the same!");
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
"the vector length is not the same!");
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
constexpr index_t num_of_access =
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
out_tensor.get_thread_buffer().template set_as<DataVec>(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});
return out_tensor;
}
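// Illustrative usage (a sketch with hypothetical names, assuming a window whose distribution
// satisfies TransposeTileDistrChecker):
//   auto in_window  = make_tile_window(in_view, window_lengths, origin,
//                                      make_static_tile_distribution(in_encoding));
//   auto transposed = load_tile_transpose(in_window); // holds data with the OutDstrEncode layout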
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
[&](auto ii) {
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
: static_cast<index_t>(idx_y_start[ii]);
},
number<NDimY>{});
constexpr auto idx_y_out =

View File

@@ -161,7 +161,8 @@ struct tensor_view
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset) const
index_t linear_offset,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(
smem,
@@ -181,7 +182,8 @@ struct tensor_view
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element) const
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(smem,
coord.get_offset() / PackedSize,
@@ -251,6 +253,33 @@ struct tensor_view
bool_constant<pre_nop>{});
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
{
return buf_.template transpose_get<X>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_transpose_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element // flag
) const
{
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,

View File

@@ -542,26 +542,26 @@ namespace detail {
//
// e.g
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
// totally 16 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slice along this P dim, will split threads, not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y dim needs to be split into 2
@@ -577,11 +577,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
constexpr auto x_slice_ends_ = generate_sequence_v2(
[&](auto i) {
if constexpr(x_slice_ends[i] == -1)
{
// -1 means till the end
constexpr auto x_length_ =
container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
return x_length_;
}
else
{
return x_slice_ends[i];
}
},
number<x_slice_ends.size()>{});
constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
[&](auto i) constexpr {
constexpr auto len_ = x_slice_lengths[i];
static_assert(len_ % p_len_over_h[i] == 0,
"slice length must be dividable by p_len_over_h");
return number<len_ / p_len_over_h[i]>{};
},
number<x_slice_lengths.size()>{});
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
constexpr auto src_y_dims = src_y_info[number<0>{}];
constexpr auto src_y_maps = src_y_info[number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
@@ -590,14 +618,15 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
{
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
auto y_slice_lengths = Encoding::detail::ys_lengths_;
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
// This lambda will modify some value outside, so C++ will not treat the return value as
// constexpr
// TODO: ugly
auto new_h_lengths = transform_tuples(
[&](auto h_len, auto id) {
constexpr auto sliced_h =
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
constexpr auto sliced_h = reverse_slice_sequence(
h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
constexpr auto sliced_h_index = sliced_h[number<2>{}];
@@ -605,26 +634,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
// update y_slice_lengths
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
"not sliced at y dim, please check");
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[found_y_index - i]) =
sliced_h_lens[sliced_h_index - i];
});
{
constexpr auto sliced_y_to_h_lens =
pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
static_for<0, sliced_y_to_h_dims, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
});
}
// TODO: add validations not across p dim
// NOTE: this y_origin is for all dims, not only current dim
// will later use pick to select target dim
constexpr auto y_origin = [&]() {
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
// can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
constexpr auto y_to_h_len =
pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
constexpr auto y_to_h_dims = y_to_h_len.size();
constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
static_for<0, y_to_h_dims, 1>{}([&](auto i) {
y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
});
return y_origin_;
}();

View File

@@ -47,6 +47,11 @@ struct tile_distribution_encoding
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
#if !CK_TILE_ENC_SUPPORT_Y_TO_R
static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
"do not support Y dim pointed to R dim");
#endif
// redundant but useful info
// TODO: really bad code, should be over-hauled
struct detail
@@ -255,33 +260,107 @@ struct tile_distribution_encoding
}
}();
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_h_dim_lengths()
{
// <len_d0, len_d1, ...>
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
[&](auto i) {
constexpr index_t size = HsLengthss{}[i].size();
return number<size>{};
constexpr index_t size_ = HsLengthss{}[i].size();
return number<size_>{};
},
number<NDimX>{});
return uniformed_h_dim_lengths;
}
// note: this function only counts the p dim length along h, not r
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_p_dim_lengths_over_h()
{
// e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
// Y P Y Y P Y P Y
// | | |
// v v v
// return : seq<4, 2 * 4> => seq<4, 8>
constexpr auto uniformed_ps_to_rhss_major_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
constexpr auto uniformed_ps_to_rhss_minor_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
constexpr auto p_len_ = [&]() {
array<index_t, NDimX> len_{1};
static_for<0, NDimX, 1>{}([&](auto idim_x_) {
constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
{
constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
len_[idim_x_] *= h_length_;
}
});
});
return len_;
}();
constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
return p_len_over_h_seq_;
}
//
// R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
// => return seq<1, 3, 5>
// R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
// => return seq<0, 2, 3>
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_rh_dim_lengths()
{
constexpr auto uniformed_rh_dim_lengths =
merge_sequences(sequence<NDimR>{} /*for R dims*/, get_uniformed_h_dim_lengths());
return uniformed_rh_dim_lengths;
}
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
{
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
return h_dim_prefix_sum;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
CK_TILE_HOST_DEVICE static constexpr auto get_rh_dim_lengths_prefix_sum()
{
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
return rh_dim_prefix_sum;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_p_to_h()
{
// tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
constexpr auto uniformed_ps_to_rhss_major_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
constexpr auto uniformed_ps_to_rhss_minor_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
constexpr auto all_ps_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
return rh_dim_prefix_sum.at(major) + minor;
},
uniformed_ps_to_rhss_major_,
uniformed_ps_to_rhss_minor_);
return all_ps_2_rhss;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_rh()
{
constexpr auto all_ys_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
// <0, 0, len_d0, len_d0+len_d1, ...>
constexpr auto x_dim_prefix_sum = merge_sequences(
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
return x_dim_prefix_sum.at(major) + minor;
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
return rh_dim_prefix_sum.at(major) + minor;
},
Ys2RHsMajor{},
Ys2RHsMinor{});
@@ -289,6 +368,45 @@ struct tile_distribution_encoding
return all_ys_2_rhss;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
{
// TODO: Y can't point to R
constexpr auto all_ys_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
return rh_dim_prefix_sum.at(major) + minor - NDimR;
},
Ys2RHsMajor{},
Ys2RHsMinor{});
return all_ys_2_rhss;
}
// return tuple of seq
CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
{
constexpr auto masks_ = generate_tuple(
[&](auto i) {
constexpr auto size_ = HsLengthss{}[i].size();
constexpr auto current_y_to_h_mask_ = [&]() {
array<index_t, size_> m_{0};
// TODO: we loop over all y for each h dim
for(auto j = 0; j < NDimY; j++)
{
if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
{
m_[Ys2RHsMinor{}[j]] = 1;
}
}
return m_;
}();
return TO_SEQUENCE(current_y_to_h_mask_, size_);
},
number<NDimX>{});
return masks_;
}
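// Illustrative example (assumption, reusing the Y/P labelling from the slicing comments in this
// change): for HsLengthss = tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> with the assignment
//     Y  P  P        Y  P  Y  P  Y
// the masks would be tuple<seq<1, 0, 0>, seq<1, 0, 1, 0, 1>>.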
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
template <typename IdxSeq, typename PrefixSumSeq>
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
@@ -305,7 +423,8 @@ struct tile_distribution_encoding
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
}
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
// Note here y_to_h does not count R dim!
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_to_h_info()
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}

View File

@@ -344,37 +344,82 @@ struct tile_window_with_static_distribution
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using Traits = typename Base::Traits;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// Precompute invariant values outside loops
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Use precomputed window origin
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_thread_coord.get_bottom_index();
// Use precomputed tensor descriptor
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// Write into bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
number<0>{},
bool_constant<oob_conditional_check>{});
// Move thread coordinate if not last access
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose() const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
this->template load_transpose<Policy>(
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename Policy,
typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto group_func = Policy::group_func;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -385,10 +430,31 @@ struct tile_window_with_static_distribution
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from bottom tensor
const vector_t vec_value =
this->get_bottom_tensor_view()
.template get_transpose_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0);
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto orig_idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<Base::NDimY>{});
constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
constexpr index_t linear_distributed_index =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
vec_value.template get_as<typename Base::DataType>()[j];
});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
@@ -400,8 +466,6 @@ struct tile_window_with_static_distribution
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
});
@@ -415,7 +479,6 @@ struct tile_window_with_static_distribution
{
using Traits = typename Base::Traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;

View File

@@ -186,7 +186,7 @@ struct tile_window_linear
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution)
: cached_coords_{}, cached_flags_{}
: cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{}
{
this->bottom_tensor_view_ = bottom_tensor_view;
this->window_lengths_ = window_lengths;
@@ -214,7 +214,8 @@ struct tile_window_linear
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
}
// TODO: need pad_tensor_view to check which dim need use flag to check
@@ -314,8 +315,7 @@ struct tile_window_linear
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor =
make_static_distributed_tensor<typename Base::DataTypeDataType>(tile_dstr);
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
@@ -348,8 +348,9 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
dst_tensor.get_thread_buffer().template at<d>() =
vec_value
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
});
};
@@ -400,8 +401,9 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
dst_tensor.get_thread_buffer().template at<d>() =
vec_value
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
});
};
@@ -553,66 +555,101 @@ struct tile_window_linear
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
using vector_t = typename traits::vector_t;
// currently we only support the case where every dim is non-linear
// actually it's not performant if we have a linear (e.g. fast-changing) dim
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
address_space_enum::global,
"Requires global memory");
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// Precompute invariant values outside the lambda
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using vector_t = typename Base::Traits::vector_t;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
// Use precomputed values
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
// read from bottom tensor
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_coord.get_bottom_index();
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// Read from bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
0,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(i_access_ != (NumAccess - 1))
{
smem += size_per_issue; // Note we manually increase the per-issue offset
}
};
WINDOW_DISPATCH_ISSUE();
}
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose() const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
this->template load_transpose_linear<Policy>(
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename Policy,
typename DistributedTensor,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto group_func = Policy::group_func;
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from bottom tensor
const vector_t vec_value =
this->get_bottom_tensor_view().template get_transpose_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<Base::NDimY>{});
constexpr index_t linear_distributed_index =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
vec_value.template get_as<typename Base::DataType>()[j];
});
};
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
typename Base::TileDstr>& dstr_tensor,
@@ -750,8 +787,7 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
vec_value.template get_as<typename Base::DataTypeDataType>()(
j / Base::Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -806,8 +842,7 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
vec_value.template get_as<typename Base::DataTypeDataType>()(
j / Base::Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -875,7 +910,8 @@ struct tile_window_linear
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
}
if constexpr(i_access != (NumAccess - 1))
@@ -895,6 +931,8 @@ struct tile_window_linear
// this contains:
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear>
cached_window_adaptor_coords_;
array<bool, Base::Traits::NumAccess> cached_flags_;
};