Merge branch 'develop' into hstu_attention_n0loop_fused_unroll

This commit is contained in:
Qianfeng Zhang
2025-08-18 13:47:44 +00:00
1876 changed files with 192746 additions and 21091 deletions

View File

@@ -10,9 +10,11 @@
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/arch/workgroup_barrier.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/container_helper.hpp"
@@ -25,19 +27,23 @@
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/shuffle_tile.hpp"
@@ -53,12 +59,15 @@
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/debug.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
@@ -66,6 +75,7 @@
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pass_through{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
template <typename LowLength>
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
{
printf("pass_through{");
printf("up_lengths_: ");
print(pt.get_upper_lengths());
printf("}");
}
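For orientation, a hypothetical usage sketch of the new free-function printers; it assumes the usual `make_pass_through_transform` helper and a compile-time length, with the free `print` found through argument-dependent lookup:

    // Hypothetical: build a pass_through transform over a length of 8 and print it.
    const auto pt = ck_tile::make_pass_through_transform(ck_tile::number<8>{});
    print(pt); // emits: pass_through{up_lengths_: ...}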
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
{
printf("pad{");
printf("up_lengths_: ");
print(p.up_lengths_);
printf(", left_pad_length_: ");
print(p.left_pad_length_);
printf(", right_pad_length_: ");
print(p.right_pad_length_);
printf("}");
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
@@ -330,24 +326,20 @@ struct left_pad
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("left_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
{
printf("left_pad{");
printf("up_lengths_: ");
print(lp.up_lengths_);
printf(", left_pad_length_: ");
print(lp.left_pad_length_);
printf("}");
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("right_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
{
printf("right_pad{");
printf("up_lengths_: ");
print(rp.up_lengths_);
printf(", right_pad_length_: ");
print(rp.right_pad_length_);
printf("}");
}
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be either of the followings:
// 1) Tuple of index_t, which is known at run-time, or
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Coefficients>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
template <typename UpLengths, typename Coefficients>
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
{
printf("embed{");
printf("up_lengths_: ");
print(e.up_lengths_);
printf(", coefficients_: ");
print(e.coefficients_);
printf("}");
}
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("merge_v2_magic_division{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
{
printf("merge_v2_magic_division{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
{
printf("merge_v3_division_mod{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", low_lengths_scan_: ");
print(m.low_lengths_scan_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("unmerge{");
//
printf("up_lengths_");
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print(up_lengths_scan_);
printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
{
printf("unmerge{");
printf("up_lengths_: ");
print(u.up_lengths_);
printf(", up_lengths_scan_: ");
print(u.up_lengths_scan_);
printf("}");
}
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
{
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
template <typename LowerIndex>
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
{
printf("freeze{");
printf("low_idx_: ");
print(f.low_idx_);
printf("}");
}
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
{
return ck_tile::is_known_at_compile_time<UpperLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("insert{");
//
print(up_lengths_);
printf("}");
}
};
template <typename UpperLength>
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
{
printf("insert{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf("}");
}
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
template <typename UpLengths>
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
{
printf("replicate{");
printf("up_lengths_: ");
print(r.up_lengths_);
printf("}");
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
ck_tile::is_known_at_compile_time<SliceEnd>::value;
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
} // namespace ck
}; // namespace ck
template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
{
printf("slice{");
printf("up_lengths_: ");
print(s.up_lengths_);
printf(", slice_begin_: ");
print(s.slice_begin_);
printf(", slice_end_: ");
print(s.slice_end_);
printf("}");
}
/*
* \brief lower_idx = upper_idx % modulus.
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
{
printf("modulo{");
printf("modulus_: ");
print(m.modulus_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("xor_t{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
{
printf("xor_t{");
printf("up_lengths_: ");
print(x.up_lengths_);
printf("}");
}
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<OffsetLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("offset{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
print(offset_length_);
printf("}");
}
};
template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
{
printf("offset{");
printf("up_lengths_: ");
print(o.up_lengths_);
printf(", offset_length_: ");
print(o.offset_length_);
printf("}");
}
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
IndexingAdaptor::is_known_at_compile_time();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
{
printf("indexing{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf(", iadaptor_: ");
print(i.iadaptor_);
printf("}");
}
//*******************************************************************************************************
template <typename LowLength>

View File

@@ -100,10 +100,8 @@ struct space_filling_curve
// Given the tensor strides \p access_strides and a 1D index of the space-filling curve,
// compute the idim-th element of the multidimensional index.
// All constexpr variables have to be captured by VALUE.
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
{
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
{
constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
auto res = idx_1d.value;
auto id = 0;

View File

@@ -1,6 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
/**
* @file
* This file defines the data access pattern for a 2D window (`XPerTile` by `YPerTile`)
for `BlockSize` threads in a thread block.
* X dimension is considered contiguous in memory, so a single instruction can access
several adjacent and properly aligned elements (vector); the access pattern along X tile
dimension is parameterized only by the suggested vector size `VecSize`.
* We can't access more than `MaxVecSize = TileElementsPerThread = TileSize / BlockSize` elements
with a single memory access, so the actual vector size along the X dimension is
`X0 = min(MaxVecSize, VecSize)`.
* This leaves `X1 = XPerTile / X0` threads per tile in X dimension.
* X1 is also the number of threads per warp in the X dimension; that is, the
X dimension is not split between warps: each warp accesses the X dimension entirely,
and there is no iteration in the X dimension.
* The tuple <X0, X1> defines the X-axis access pattern.
This part is common between the 2D distribution patterns.
* What differs between the 2D distribution patterns is the Y-axis access pattern.
* There are 3 components in this access pattern:
* (1) number of Y-axis elements (rows) per warp for a single instruction access,
* (2) number of warps per thread block,
* (3) number of iterations to cover the entire Y axis.
* The term 'raked' describes how data is partitioned across the different processing
granularities: whether contiguous regions are accessed per thread, per warp, or per block.
* As listed below, the qualifier for 'raked' names the level of the warp/thread hierarchy,
* within the split of the Y tile dimension, at which the iteration happens;
* that is, the iteration can be logically inserted as a tile dimension in 3 ways:
* (1) after thread -> thread-raked,
* (2) between warp and thread -> warp-raked,
* (3) before warp -> block-raked
* *Thread raked*
* Y0 is the number of warps, which we can get from the equation `Y0 * WarpSize == BlockSize`
* Y1 is the number of rows accessed by a warp within a single iteration;
compute it from the equation `Y1 * X1 == WarpSize`
* Y2 is the number of iterations to cover the tile,
compute it from the equation `Y0 * Y1 * Y2 == YPerTile`
* *Warp raked*
* Y0 is the number of warps, obtained the same way as in the thread-raked pattern,
`Y0 * WarpSize == BlockSize`
* Y1 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`;
compute Y1 once Y2 is known from the equation below
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
* *Block raked*
* Y0 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
Compute Y1 and Y2 from the equations below
* Y1 is the number of warps, `Y1 * WarpSize == BlockSize`
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
* In all cases, the tuple <Y0, Y1, Y2> defines the Y-axis access pattern.
* *Selection*
* Thread-raked is typically used for element-wise operations because it yields a
* thread-major memory order.
* Warp-raked is used for matrix multiplication because vectorization happens at the warp level.
* Block-raked is used mostly for reductions, where the block is reduced through global
* atomics.
*
*/
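To make the arithmetic above concrete, here is a minimal compile-time sketch of the thread-raked split for one assumed configuration (256-thread block, 64-lane wavefronts, a 32x64 tile, suggested vector size 8). Note that the header comment and the code below swap the roles of X0 and X1; the sketch follows the code, where X1 is the vector size and X0 the thread count along X. All concrete values are illustrative assumptions:

    // Hedged sketch of the thread-raked split; every constant here is assumed.
    constexpr int BlockSize = 256, WarpSize = 64;
    constexpr int YPerTile = 32, XPerTile = 64, VecSize = 8;
    constexpr int NumWarps = BlockSize / WarpSize;                          // 4 warps
    constexpr int MaxVec   = (XPerTile * YPerTile) / (NumWarps * WarpSize); // 8 elems/thread
    constexpr int X1 = VecSize < MaxVec ? VecSize : MaxVec;                 // 8-wide vector
    constexpr int X0 = XPerTile / X1;                                       // 8 threads along X
    constexpr int Y1 = WarpSize / X0;                                       // 8 rows/warp/access
    constexpr int Y0 = NumWarps;                                            // 4 warps along Y
    constexpr int Y2 = YPerTile / (Y0 * Y1);                                // 1 iteration
    static_assert(X0 * Y1 == WarpSize && Y0 * Y1 * Y2 == YPerTile, "consistent split");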
#pragma once
#include "ck_tile/core/arch/arch.hpp"
@@ -10,6 +77,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -56,78 +124,116 @@ template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern>
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups = 1>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked>
: public TileDistributionEncodingPattern
tile_distribution_pattern::thread_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
// # of rows in Y dim accessed by single wavefront in one iteration
static constexpr index_t Y1 = warp_size / X0;
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
static constexpr index_t Y0 = num_warps;
static constexpr index_t Y0 = num_warps / NumWaveGroups;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
"X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(NumWaveGroups != 1)
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Y0>,
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0>, sequence<1, 2>>,
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<1, 1>>{}); // -> <Y2, X1>
}
else
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<2, 1>>{}); // -> <Y2, X1>
}
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<1, 2>>{});
if constexpr(NumWaveGroups != 1)
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Y0>,
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<1, 1>>{}); // -> <X1, Y2>
}
else
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
sequence<1, 2>,
sequence<1, 2>>{}); // -> <X1, Y2>
}
}
};
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked>
: public TileDistributionEncodingPattern
tile_distribution_pattern::warp_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
@@ -144,9 +250,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
sequence<1, 2>,
sequence<1, 1>>{});
sequence<1, 1>>{}); // -> <Y1, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
@@ -155,28 +261,33 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<2, 0>>,
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
sequence<1, 2>,
sequence<1, 1>>{});
sequence<1, 1>>{}); // -> <X1, Y1>
}
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked>
: public TileDistributionEncodingPattern
tile_distribution_pattern::block_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
static constexpr index_t Y1 = num_warps;
@@ -190,9 +301,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
sequence<1, 2>,
sequence<0, 1>>{});
sequence<0, 1>>{}); // -> <Y0, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
@@ -201,10 +312,57 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
sequence<1, 2>,
sequence<1, 0>>{});
sequence<1, 0>>{}); // -> <X1, Y0>
}
};
// Helper function to convert enum to string
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
{
switch(pattern)
{
case tile_distribution_pattern::thread_raked: return "thread_raked";
case tile_distribution_pattern::warp_raked: return "warp_raked";
case tile_distribution_pattern::block_raked: return "block_raked";
default: return "unknown";
}
}
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups>
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>&)
{
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>;
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
"VecSize:%d, %s>: ",
BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern_to_string(DistributionPattern));
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
PatternType::Y0,
PatternType::Y1,
PatternType::Y2,
PatternType::X0,
PatternType::X1);
}
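A hypothetical host-side call to the printer above, assuming 64-lane wavefronts and reusing the illustrative 256-thread, 32x64-tile, 8-wide-vector configuration:

    // Hypothetical usage sketch; the expected output assumes get_warp_size() == 64.
    using Pattern = ck_tile::TileDistributionEncodingPattern2D<
        256, 32, 64, 8, ck_tile::tile_distribution_pattern::thread_raked, 1>;
    print(Pattern{}); // -> {<Y0, Y1, Y2>: <4, 8, 1>, <X0, X1>: <8, 8>}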
} // namespace ck_tile

View File

@@ -13,6 +13,18 @@
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
// This attribute hints to the compiler that a branch is likely to be taken.
// The compiler can then, where possible, remove the associated s_cbranch_execz branch that
// would otherwise have been generated.
#if __cplusplus >= 202002L
#define LIKELY(x) (x) [[likely]]
#else
#define LIKELY(x) (__builtin_expect(!!(x), 1))
#endif
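A quick sketch of what the macro expands to at a use site like the ones further down (`flag` is just a stand-in):

    // C++20:     if LIKELY(1 <= flag) { ... }  =>  if (1 <= flag) [[likely]] { ... }
    // otherwise: if LIKELY(1 <= flag) { ... }  =>  if (__builtin_expect(!!(1 <= flag), 1)) { ... }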
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
@@ -54,10 +66,36 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
// TODO: glc/slc/...
template <index_t bytes, bool pre_nop = false>
struct buffer_load;
template <index_t bytes, bool pre_nop = false>
struct buffer_load_if;
template <index_t bytes>
struct buffer_store;
template <index_t bytes>
struct buffer_store_if;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: the strict aliasing rule seems to fail when reinterpret_cast-ing between vector types
// (exp_vector_type(xxx))
#define HAS_RAW_BUFFER_BUILTINS \
__has_builtin(__builtin_amdgcn_raw_buffer_load_b32) && \
__has_builtin(__builtin_amdgcn_make_buffer_rsrc) && \
__has_builtin(__builtin_amdgcn_raw_buffer_store_b32)
#if HAS_RAW_BUFFER_BUILTINS
CK_TILE_DEVICE __amdgpu_buffer_rsrc_t cast_to_amdgpu_buffer_rsrc_t(int32x4_t res)
{
__amdgpu_buffer_rsrc_t as_rsrc;
static_assert(sizeof(res) == sizeof(as_rsrc) && "Size of buffer resource should match");
memcpy(&as_rsrc, &res, sizeof(res));
return as_rsrc;
}
#endif
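For orientation, a hedged sketch of how this helper slots in; `make_wave_buffer_resource` is the existing resource builder, while `ptr` and `bytes` are stand-ins. The memcpy round trip avoids the reinterpret_cast aliasing concern flagged in the TODO above:

    // Hypothetical: rebuild the opaque rsrc handle from the 4-dword resource words.
    const int32x4_t words = make_wave_buffer_resource(ptr, bytes);
    const __amdgpu_buffer_rsrc_t rsrc = cast_to_amdgpu_buffer_rsrc_t(words);
    // rsrc can then feed __builtin_amdgcn_raw_buffer_load_b32 and friends.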
template <bool pre_nop>
struct buffer_load<16, pre_nop>
{
@@ -72,6 +110,11 @@ struct buffer_load<16, pre_nop>
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b128(
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
@@ -83,6 +126,7 @@ struct buffer_load<16, pre_nop>
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -100,6 +144,11 @@ struct buffer_load<8, pre_nop>
{
static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b64(
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
@@ -111,6 +160,7 @@ struct buffer_load<8, pre_nop>
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -128,6 +178,12 @@ struct buffer_load<4, pre_nop>
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b32(
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3"
@@ -139,6 +195,7 @@ struct buffer_load<4, pre_nop>
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -156,6 +213,12 @@ struct buffer_load<2, pre_nop>
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b16(
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
@@ -167,6 +230,7 @@ struct buffer_load<2, pre_nop>
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -184,6 +248,11 @@ struct buffer_load<1, pre_nop>
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b16(
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
@@ -195,12 +264,31 @@ struct buffer_load<1, pre_nop>
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
template <index_t bytes, bool pre_nop = false>
struct buffer_load_if;
#if HAS_RAW_BUFFER_BUILTINS
template <index_t bytes, bool pre_nop>
struct buffer_load_if
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
if LIKELY(1 <= flag)
{
buffer_load<bytes, pre_nop>{}(
value, res, v_offset, s_offset, i_offset, flag, bool_constant<pre_nop>{});
}
}
};
#else
template <bool pre_nop>
struct buffer_load_if<16, pre_nop>
{
@@ -366,9 +454,9 @@ struct buffer_load_if<1, pre_nop>
: "memory");
}
};
#endif
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
template <index_t bytes>
struct buffer_store;
template <>
struct buffer_store<16>
@@ -383,10 +471,16 @@ struct buffer_store<16>
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b128(
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -403,10 +497,16 @@ struct buffer_store<8>
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b64(
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -423,10 +523,16 @@ struct buffer_store<4>
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b32(
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -443,10 +549,16 @@ struct buffer_store<2>
{
static_assert(sizeof(T) == 2);
using mbuf_t = short;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b16(
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
@@ -463,16 +575,38 @@ struct buffer_store<1>
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b8(
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
#else
asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
#if HAS_RAW_BUFFER_BUILTINS
template <index_t bytes>
struct buffer_store_if;
struct buffer_store_if
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
if LIKELY(1 <= flag)
{
buffer_store<bytes>{}(value, res, v_offset, s_offset, i_offset);
}
}
};
#else
template <>
struct buffer_store_if<16>
{
@@ -613,6 +747,7 @@ struct buffer_store_if<1>
: "memory");
}
};
#endif
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
@@ -1134,7 +1269,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
@@ -1179,6 +1314,17 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
template <index_t N,
@@ -1301,8 +1447,10 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, bf16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
@@ -1425,6 +1573,54 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
thread_buffer<float, 8> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 32)
{
thread_buffer<float, 16> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<2>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<3>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(float),
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
{
@@ -1461,6 +1657,54 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
thread_buffer<float, 8> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 32)
{
thread_buffer<float, 16> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<2>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<3>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(float),
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else // other datatype
{
@@ -1515,7 +1759,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
@@ -1545,29 +1789,35 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t bytes = sizeof(T) * N;
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
// Remove this assert once other sizes are implemented.
assert(src_immediate_addr_offset == 0 &&
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
ignore = src_immediate_addr_offset;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
src_wave_addr_offset = 0;
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
#endif
// Set up v_offset:
index_t v_offset = src_thread_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
}
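(For context on the rewrite above: when `flag` is false, `v_offset` is redirected to `src_wave_buffer_resource[2]`, the num_records word of the buffer descriptor, so the address lands outside the buffer's valid range and the hardware treats the load as out of bounds instead of fetching live data.)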
template <index_t N,
@@ -2511,44 +2761,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
T* lds_base_ptr,
const index_t lds_offset,
const bool is_valid,
const index_t src_element_space_size)
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource =
make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T));
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
"s"(src_resource)
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
}
#endif
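A hypothetical device-side sketch of driving the gfx950 transpose load above; the wrapper name `load_fragment` and its LDS pointer argument are illustrative, not part of the commit:

    // Each lane pulls a transposed 4-element fp16 fragment out of LDS,
    // backed by __builtin_amdgcn_ds_read_tr16_b64_v4f16.
    __device__ auto load_fragment(const ck_tile::half_t* lds_tile_ptr)
    {
        return ck_tile::amd_transpose_load_to_vgpr<ck_tile::half_t, 4,
            ck_tile::address_space_enum::lds>(lds_tile_ptr);
    }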
} // namespace ck_tile

View File

@@ -13,6 +13,9 @@
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
@@ -29,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}
@@ -881,95 +880,95 @@ CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
CK_TILE_DEVICE_EXTERN int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
CK_TILE_DEVICE_EXTERN int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16
CK_TILE_DEVICE_EXTERN int16_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
CK_TILE_DEVICE_EXTERN int16x2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
CK_TILE_DEVICE_EXTERN int16x4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
// buffer load i32
CK_TILE_DEVICE_EXTERN int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
CK_TILE_DEVICE_EXTERN int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// buffer load fp16
CK_TILE_DEVICE_EXTERN _Float16
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_load_fp16x2(
int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16.v4i32");
CK_TILE_DEVICE_EXTERN fp16x2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
CK_TILE_DEVICE_EXTERN fp16x4_t llvm_amdgcn_raw_buffer_load_fp16x4(
int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16.v4i32");
CK_TILE_DEVICE_EXTERN fp16x4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// buffer load fp32
CK_TILE_DEVICE_EXTERN float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
CK_TILE_DEVICE_EXTERN fp32x2_t llvm_amdgcn_raw_buffer_load_fp32x2(
int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32.v4i32");
CK_TILE_DEVICE_EXTERN fp32x2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
CK_TILE_DEVICE_EXTERN fp32x4_t llvm_amdgcn_raw_buffer_load_fp32x4(
int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32.v4i32");
CK_TILE_DEVICE_EXTERN fp32x4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// buffer store i8
CK_TILE_DEVICE_EXTERN void
@@ -977,21 +976,21 @@ llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
// buffer store i16
CK_TILE_DEVICE_EXTERN void
@@ -999,21 +998,21 @@ llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x2(
int16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x4(
int16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
// buffer store i32
CK_TILE_DEVICE_EXTERN void
@@ -1021,7 +1020,7 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
// buffer store ui16
CK_TILE_DEVICE_EXTERN void
@@ -1029,35 +1028,35 @@ llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x2(
uint16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x4(
uint16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(
int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(
int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// buffer store fp16
CK_TILE_DEVICE_EXTERN void
@@ -1065,21 +1064,21 @@ llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x2(
fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x4(
fp16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// buffer store fp32
CK_TILE_DEVICE_EXTERN void
@@ -1087,21 +1086,21 @@ llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x2(
fp32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x4(
fp32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32.v4i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer atomic-add fp16
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
@@ -1109,7 +1108,7 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
@@ -1117,7 +1116,7 @@ CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// buffer atomic-add fp32
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
@@ -1125,25 +1124,25 @@ CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-max fp64
CK_TILE_DEVICE_EXTERN double llvm_amdgcn_raw_buffer_atomic_max_fp64(
double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32");
CK_TILE_DEVICE_EXTERN double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32");
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
@@ -1183,6 +1182,17 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
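// Usage sketch (assumption, not part of this diff): the coherence value is
// forwarded verbatim as the aux/cache-policy operand of the raw buffer
// intrinsics declared above, so a streaming (non-temporal, system-scope)
// store could be issued as:
//
//   llvm_amdgcn_raw_buffer_store_fp32(
//       val, rsrc, voffset, /*soffset=*/0,
//       static_cast<index_t>(amd_buffer_coherence_enum::SYSTEM_NT1));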
template <index_t N,
@@ -1549,29 +1559,35 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t bytes = sizeof(T) * N;
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
// Remove this assert once other sizes are implemented.
assert(src_immediate_addr_offset == 0 &&
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
ignore = src_immediate_addr_offset;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
src_wave_addr_offset = 0;
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
#endif
// Set up v_offset:
index_t v_offset = src_thread_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
}
template <index_t N,
@@ -2523,11 +2539,6 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource =
@@ -2544,16 +2555,70 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
"s"(src_resource)
: "memory");
#else
// Direct loads require that each thread reads and writes exactly a single DWORD.
#if defined(__gfx9__)
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
#endif
// Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes).
// For gfx950: supports 1, 3, or 4 DWORDs per thread
// For gfx942: supports exactly 1 DWORD per thread
#if defined(__gfx950__)
constexpr auto dword_bytes = 4;
static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
bytes_per_thread == dword_bytes * 4);
#elif defined(__gfx9__)
constexpr auto dword_bytes = 4;
static_assert(bytes_per_thread == dword_bytes);
#endif
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
#endif
}
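// Usage sketch (assumption, not part of this diff): on gfx950 each lane may
// now copy up to 4 DWORDs straight from global memory into LDS, e.g. 6 fp16
// elements (3 DWORDs) per thread; parameter names below are illustrative:
//
//   __shared__ fp16_t tile[kTileElems];
//   amd_direct_load_global_to_lds<fp16_t, 6>(
//       p_global, global_elem_offset, tile, lds_elem_offset,
//       /*is_valid=*/true, src_element_space_size);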
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
}
#endif
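// Usage sketch (assumption, not part of this diff): in_ptr must be an LDS
// address; each lane group cooperatively reads a small tile and receives it
// back transposed in VGPRs, e.g. 4 half_t values per lane:
//
//   auto frag = amd_transpose_load_to_vgpr<half_t, 4, address_space_enum::lds>(
//       smem_tile + lane_offset); // thread_buffer<half_t, 4>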
} // namespace ck_tile
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN

View File

@@ -0,0 +1,88 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
// this generates a wave-level tile distribution
template <typename T, index_t LaneGroupSize = 16, typename = void>
struct LaneGroupTransposeTraits;
template <typename T, index_t LaneGroupSize>
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
{
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
"LaneGroupSize must be 16, 32, or 64");
// before transpose, 4x16
static constexpr index_t ksecondDim = 4;
static constexpr index_t kleadDim = LaneGroupSize;
// after transpose, 16x4
static constexpr index_t ksecondDimT = LaneGroupSize;
static constexpr index_t kleadDimT = 4;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution = tile_distribution_encoding<
sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
tuple<sequence<1, 2, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2, 3>>,
sequence<2, 1, 2>,
sequence<1, 1, 4>>;
};
template <typename T, index_t LaneGroupSize>
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
{
static constexpr index_t ksecondDim = 8;
static constexpr index_t kleadDim = LaneGroupSize;
static constexpr index_t ksecondDimT = LaneGroupSize;
static constexpr index_t kleadDimT = 8;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution = tile_distribution_encoding<
sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
tuple<sequence<1, 2, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2, 3>>,
sequence<2, 1, 2>,
sequence<1, 1, 4>>;
};
/*
* @brief Generates the transposed distribution encoding for the given data
* type and distribution dimensions.
*
* @tparam T The data type of the elements in the tensor.
* @tparam LaneGroupSize The number of lanes in a lane group (16, 32, or 64).
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is the outer dimension for
* stride.
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is the inner dimension for
* stride.
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is the outer dimension for
* consecutive elements.
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is the inner dimension for
* consecutive elements.
*/
template <typename T,
index_t LaneGroupSize,
index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
{
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
}
} // namespace ck_tile
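// Usage sketch (assumption, not part of this diff): build the transposed
// wave-level encoding for fp16 with the default 16-lane groups in device
// code, then feed it to the tile-distribution machinery:
//
//   using namespace ck_tile;
//   constexpr auto enc =
//       make_transposed_distr_encode<half_t, /*LaneGroupSize=*/16,
//                                    /*kOuterDistDim0=*/2, /*kOuterDistDim1=*/1,
//                                    /*kInnerDistDim0=*/2, /*kInnerDistDim1=*/1>();
//   // e.g. constexpr auto dist = make_static_tile_distribution(enc);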

View File

@@ -9,6 +9,16 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
#define CK_TILE_VMCNT(cnt) \
([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))
#define CK_TILE_EXPCNT(cnt) \
([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))
#define CK_TILE_LGKMCNT(cnt) \
([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8))
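// Worked example (illustrative): the packed s_waitcnt immediate that waits
// for all outstanding vector-memory ops while leaving EXP/LGKM unconstrained
// is
//   CK_TILE_VMCNT(0) | CK_TILE_EXPCNT(7) | CK_TILE_LGKMCNT(15) == 0x0f70
// i.e. the split VM bits are zeroed while the EXP and LGKM fields stay at
// their maximum values.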
namespace ck_tile {
@@ -50,8 +60,11 @@ enum struct memory_operation_enum : std::uint16_t
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
// warpSize is defined by HIP
return warpSize;
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
return 64;
#else
return 32;
#endif
}
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
@@ -81,21 +94,6 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier();
#else
__syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
@@ -114,13 +112,68 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
#endif
}
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
struct waitcnt_arg
{
// bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
// [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111;
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10));
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_expcnt()
{
static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]");
return MAX & (cnt << 4);
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
return MAX & (cnt << 8);
}
};
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t expcnt = waitcnt_arg::kMaxExpCnt,
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt()
{
__builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
waitcnt_arg::from_expcnt<expcnt>() |
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
}
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t expcnt = waitcnt_arg::kMaxExpCnt,
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_barrier()
{
s_waitcnt<vmcnt, expcnt, lgkmcnt>();
__builtin_amdgcn_s_barrier();
}
template <index_t lgkmcnt = 0>
CK_TILE_DEVICE void block_sync_lds()
{
s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
}
template <index_t vmcnt = 0>
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
}
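// Usage sketch (illustrative): the templated barriers let callers wait on a
// specific counter only. After issuing LDS stores:
//
//   block_sync_lds<0>();             // s_waitcnt lgkmcnt(0); s_barrier
//
// and after direct global->LDS loads:
//
//   block_sync_lds_direct_load<0>(); // s_waitcnt vmcnt(0); s_barrier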
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
@@ -158,4 +211,44 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres
#pragma clang diagnostic pop
}
CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
{
#if defined(__gfx950__)
return 163840;
#else
return 65536;
#endif
}
/// Helper function to convert address space enum to string
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
{
switch(addr_space)
{
case address_space_enum::generic: return "generic";
case address_space_enum::global: return "global";
case address_space_enum::lds: return "lds";
case address_space_enum::sgpr: return "sgpr";
case address_space_enum::constant: return "constant";
case address_space_enum::vgpr: return "vgpr";
default: return "unknown";
}
}
// Architecture tags
struct gfx11_t
{
};
struct gfx12_t
{
};
CK_TILE_DEVICE static constexpr auto get_device_arch()
{
#if defined(__gfx11__)
return gfx11_t{};
#else // if defined(__gfx12__)
return gfx12_t{};
#endif
}
} // namespace ck_tile

View File

@@ -6,6 +6,10 @@
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
namespace ck_tile {
template <typename T, typename ComputeType>
@@ -32,6 +36,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
return rtn;
}
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
{
fp16x2_t rtn;
rtn[0] = add<fp16_t, float>(a[0], b[0]);
rtn[1] = add<fp16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
@@ -304,6 +316,44 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
} while(cur_v.u64 != old_v);
}
//
// Atomic add for fp16x2_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
{
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
#else
union U32F162_ADDR
{
uint32_t* u32_a;
fp16x2_t* f162_a;
};
union U32F162
{
uint32_t u32;
fp16x2_t f162;
};
U32F162_ADDR dword_addr;
U32F162 cur_v;
U32F162 new_;
uint32_t old_v, new_v;
dword_addr.f162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.f162 = add_f16x2_t(cur_v.f162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
#endif
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
@@ -311,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
@@ -406,6 +457,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp16_t>::value)
{
static_for<0, N / 2, 1>{}([&](auto i) {
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
x.template get_as<fp16x2_t>()[i]);
});
}
}
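// Usage sketch (illustrative): accumulating a 4-element fp16 fragment into
// global memory now lowers to two packed v2f16 atomic adds:
//
//   thread_buffer<fp16_t, 4> acc = ...;
//   atomic_add_g(p_out + row_offset, acc); // two fp16x2_t atomic_add calls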
template <typename T, index_t N>

View File

@@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
@@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template <typename T>
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
{
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const int32x2_t x = __builtin_amdgcn_permlane32_swap(
bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
thread_buffer<T, 2> v;
v(0) = bit_cast<T>(x[0]);
v(1) = bit_cast<T>(x[1]);
return v;
}
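// Illustrative note (assumption): permlane32_swap exchanges the two 32-lane
// halves of a wave64, so element 0 holds this lane's value and element 1 the
// value from the lane 32 apart, which enables cross-half reductions:
//
//   auto pair = warp_shuffle_down_pair(partial);
//   auto sum  = pair[0] + pair[1]; // assumes T supports operator+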
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
struct workgroup_barrier
{
CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0)
{
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
}
CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0)
{
if(threadIdx.x == 0)
{
while(ld(offset) != value) {}
}
__syncthreads();
}
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
{
if(threadIdx.x == 0)
{
while(ld(offset) < value) {}
}
__syncthreads();
}
CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0)
{
if(threadIdx.x == 0)
{
while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
}
__syncthreads();
}
// enter critical zone; assumes the buffer is zero when the kernel is launched
CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(0, 1, offset); }
// exit critical zone; assumes the buffer is zero when the kernel is launched
CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(1, 0, offset); }
CK_TILE_DEVICE void inc(uint32_t offset = 0)
{
__syncthreads();
if(threadIdx.x == 0)
{
atomicAdd(base_ptr + offset, 1);
}
}
uint32_t* base_ptr;
};
} // namespace ck_tile
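// Usage sketch (assumption, not part of this diff): cross-workgroup
// producer/consumer handshake over a zero-initialized counter in global
// memory:
//
//   ck_tile::workgroup_barrier bar{sync_buf};
//   // producer block, once its results are globally visible:
//   bar.inc();        // __syncthreads(), then lane 0 does atomicAdd
//   // consumer block:
//   bar.wait_eq(1);   // spin on lane 0 until the counter reaches 1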

View File

@@ -15,7 +15,8 @@
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
defined(__gfx1103__) || defined(__gfx11_generic__)
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
defined(__gfx1152__) || defined(__gfx11_generic__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
@@ -28,12 +29,6 @@
#include "hip/hip_fp16.h"
#endif
#include "ck_tile/core/utility/env.hpp"
// environment variable to enable logging:
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
#ifdef __HIPCC__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
@@ -157,7 +152,7 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx9__) // for GPU code
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
@@ -196,6 +191,16 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
#endif
#endif
// use llvm builtin bf16 data type after ROCm 6.5
#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
(HIP_VERSION_MAJOR >= 7)
#define CK_TILE_USE_LLVM_BUILTIN_BF16 1
#else
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
@@ -228,6 +233,10 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
#endif
#ifndef CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
#define CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN 0
#endif
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
@@ -241,22 +250,38 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#ifndef CK_TILE_USE_OCP_FP8
#if defined(__HIP_DEVICE_COMPILE__)
#if defined(__gfx950__) || defined(__gfx12__)
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#endif
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
#if __clang_major__ == 20
#if __clang_major__ >= 20
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
#else
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
#endif
#endif
#ifndef CK_TILE_WA_ISSUE_2028
#define CK_TILE_WA_ISSUE_2028 0
#endif
#ifndef CK_TILE_WAVE32_ENABLED
#if defined(__gfx11__) || defined(__gfx12__)
#define CK_TILE_WAVE32_ENABLED
#endif
#endif
// Y mapped to R: we don't see a valuable use case.
// If set to zero, the encoding check will enforce that Y is not mapped to R.
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
#define CK_TILE_ENC_SUPPORT_Y_TO_R 0
#endif

View File

@@ -19,6 +19,25 @@ namespace ck_tile {
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
// use make_array_with({...}) to construct an array whose behavior is compatible with old ck
// TODO: manually added constructor same as old ck
/**
* @brief A fixed-size array container similar to std::array with additional utilities.
*
* This template class provides a lightweight fixed-size array with value semantics,
* supporting both host and device functionality for GPU programming. It includes
* specialized initialization methods and type punning capabilities.
*
* @tparam T_ The type of elements in the array
* @tparam N_ The fixed number of elements in the array
*
* @note This implementation provides additional features beyond std::array:
* - GPU compatibility via CK_TILE_HOST_DEVICE macros
* - Type punning via get_as() and set_as() methods
* - Various specialized access methods
* - Specialized initialization behaviors
*
* The initializer_list constructor fills the remaining elements with the last value
* provided when the list is shorter than N, which differs from std::array.
*/
template <typename T_, index_t N_>
struct array
{
@@ -142,6 +161,14 @@ struct array
// empty Array
/// @brief Specialization of array container for zero elements.
///
/// This is a specialization of the array container template for the case where the number of
/// elements is 0. It provides the same interface as the general array template, but with operations
/// appropriate for an empty array.
///
/// @tparam T The type of elements stored in the array (not used in this specialization but
/// maintained for API consistency).
template <typename T>
struct array<T, 0>
{
@@ -150,9 +177,27 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename T, index_t N>
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
{
printf("array{size: %ld, data: [", static_cast<long>(N));
for(index_t i = 0; i < N; ++i)
{
if(i > 0)
printf(", ");
print(a[i]);
}
printf("]}");
}
template <typename T>
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
{
printf("array{size: 0, data: []}");
}
template <typename, typename>
struct vector_traits;

View File

@@ -16,7 +16,7 @@ template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
r[number<NSize>{}] = x;
return r;
}

View File

@@ -139,26 +139,21 @@ struct map
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
template <typename key, typename data, index_t max_size>
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
{
printf("map{size_: %d, impl_: [", m.size_);
for(const auto& [k, d] : m)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
printf("]}");
}
} // namespace ck_tile

View File

@@ -9,13 +9,10 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct sequence;
@@ -196,15 +193,24 @@ struct sequence
{
return sequence<f(Is)...>{};
}
CK_TILE_HOST_DEVICE static void print()
{
printf("sequence{size: %d, data: [", size());
((printf("%d ", Is)), ...);
printf("]}");
}
};
template <index_t... Is>
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
{
printf("sequence<");
if constexpr(sizeof...(Is) > 0)
{
bool first = true;
(([&first](index_t value) {
printf("%s%d", first ? "" : ", ", value);
first = false;
}(Is)),
...);
}
printf(">");
}
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;
@@ -1178,6 +1184,15 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
// the length count for the slice size runs from right to left (reverse slice)
// in other words, find the greatest common divisor (gcd) from right to left for the slice length
//
// e.g. <2, 8, 4>, slice length = 16
// step-1: we take the right-most <*, *, 4>, remaining 16/4=4
// step-2: we only need 4 out of 8 of the middle dim, hence <*, 4, 4>
// step-3: since nothing remains, the first dim only needs 1, hence <1, 4, 4>
// => we got <1, 4, 4> as the length of each slice
// => total number of slices = <2, 8, 4> / <1, 4, 4> = <2, 2, 1>
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
@@ -1197,7 +1212,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// or the first index (right -> left) that sliced length is different from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
@@ -1207,6 +1222,11 @@ constexpr auto reverse_slice_sequence(Seq,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
static_assert(SliceSize != 0, "slice size zero is invalid");
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
SliceSize ==
0,
"slice size can't evenly divide input sizes");
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,
@@ -1222,9 +1242,8 @@ constexpr auto reverse_slice_sequence(Seq,
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
constexpr auto
slice_sequence(Seq, number<SliceSize>, Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
constexpr auto r =
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());

View File

@@ -42,7 +42,11 @@ struct thread_buffer {
// TODO: this ctor can't ignore
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} {
static_for<0, N, 1>{}(
[&](auto i) { data[i] = o; }
);
}
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE auto & get() {return data; }

View File

@@ -262,12 +262,18 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
return flag;
}
CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; }
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
@@ -294,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename, typename = void>
template <typename... T>
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
{
printf("tuple<");
if constexpr(sizeof...(T) > 0)
{
bool first = true;
static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) {
if(!first)
printf(", ");
print(t.get(i));
first = false;
});
}
printf(">");
}
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
struct vector_traits<tuple<T...>, void>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);
@@ -396,11 +419,16 @@ struct tuple_array_impl<T, 1>
};
} // namespace impl
template <typename F, index_t... ids>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple_for(F&& f, sequence<ids...>)
{
return make_tuple(f(number<ids>{})...);
}
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
return generate_tuple_for(f, make_index_sequence<N>{});
}
template <typename F, index_t N>
@@ -465,6 +493,12 @@ transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}
template <typename F, typename Tuple, index_t... Is>
constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence<Is...>)
{
return std::forward<F>(f)(std::forward<Tuple>(t).get(number<Is>{})...);
}
} // namespace detail
template <typename F, typename X>
@@ -488,6 +522,13 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename Tuple>
constexpr decltype(auto) apply(F&& f, Tuple&& t)
{
constexpr index_t N = std::decay_t<Tuple>::size();
return detail::apply_impl(std::forward<F>(f), std::forward<Tuple>(t), make_index_sequence<N>{});
}
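// e.g. (illustrative) apply() mirrors std::apply for ck_tile::tuple:
//
//   auto sum = apply([](auto... xs) { return (xs + ...); },
//                    make_tuple(1, 2, 3)); // sum == 6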
namespace detail {
template <typename F, typename X, index_t... Is>

View File

@@ -6,6 +6,9 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#if CK_TILE_USE_LLVM_BUILTIN_BF16
#include <hip/hip_bfloat16.h>
#endif
#include <stdint.h>
#pragma once
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
#if CK_TILE_USE_LLVM_BUILTIN_BF16
using bfloat16_t = __bf16;
#else
using bfloat16_t = ushort;
#endif
using bf16_t = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined(__gfx950__)
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
#endif
}
template <bf16_rounding_mode rounding =

View File

@@ -0,0 +1,102 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
namespace ck_tile {
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
using raw_type = uint8_t;
using type = raw_type;
raw_type data;
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
{
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
CK_TILE_HOST_DEVICE constexpr operator float() const;
constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; }
constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; }
};
using e8m0_t = e8m0_bexp_t;
using e8m0_raw_t = typename e8m0_t::raw_type;
template <>
struct numeric_traits<e8m0_t>
{
using bitwise_type = e8m0_raw_t;
static constexpr int exp = 8;
static constexpr int mant = 0;
static constexpr int bias = 127;
static constexpr int PackedSize = 1;
};
// limits
template <class T>
struct numeric;
template <>
struct numeric<e8m0_t>
{
static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127
static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127
static constexpr e8m0_raw_t binary_nan = 0b11111111;
CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; }
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); }
CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); }
CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); }
CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); }
};
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
{
using traits = numeric_traits<float>;
if(data == numeric<e8m0_t>::binary_nan)
{
return std::numeric_limits<float>::signaling_NaN();
}
else if(data == 0)
{
return std::numeric_limits<float>::min();
}
else
{
return bit_cast<float>(static_cast<traits::bitwise_type>(data) << traits::mant);
}
}
} // namespace ck_tile
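// Worked examples (illustrative) of the biased-exponent encoding above:
//
//   e8m0_t{uint8_t{0b01111111}} -> 2^(127-127) = 1.0f
//   e8m0_t{uint8_t{0b10000111}} -> 2^(135-127) = 256.0f
//   e8m0_t{2.0f}                -> raw 0b10000000 (exponent bits of 2.0f)
//   e8m0_t{}                    -> NaN (default-constructed to binary_nan)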

View File

@@ -43,19 +43,19 @@ enum class fp8_interpretation
};
/*
* ______________FNUZ_________________ | ______________OCP________________
* ______________FNUZ_________________ | ______________OCP________________
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* inf : N/A N/A | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
* 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
@@ -75,7 +75,7 @@ struct alignas(1) float8_e4m3_t
#if CK_TILE_USE_OCP_FP8
static constexpr int bias = 7; // OCP
#else
static constexpr int bias = 8; // FNUZ
static constexpr int bias = 8; // FNUZ
#endif
using raw_type = uint8_t;
raw_type data;
@@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// fp8/bf8 type exponent/mantissa layout
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr int DstT_bias = numeric_traits<DstT>::bias;
constexpr bool is_fnuz =
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int bias = numeric_traits<SrcT>::bias;
constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
unsigned long long head, mantissa;
int exponent, bias;
unsigned int head, mantissa;
int exponent;
unsigned int sign;
unsigned long long fInf, abs_mask;
head = src_bitwise & numeric_traits<SrcT>::head_mask;
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
sign = head >> (SrcT_exp + SrcT_mant);
bias = numeric_traits<SrcT>::bias;
fInf = numeric_traits<SrcT>::Inf;
abs_mask = numeric_traits<SrcT>::abs_mask;
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
nan = 0x80;
}
else
{
if constexpr(DstT_exp == 4)
{ // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
unsigned int ifmax = 0;
if constexpr(is_float)
{
if constexpr(DstT_exp == 5)
@@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// Deal with inf and NaNs
if((src_bitwise & fInf) == fInf)
{
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
@@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
return signed_inf;
}
if(src_bitwise == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
@@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
@@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// for this case, act_exponent could be larger. Just
// that it does not need shift mantissa
}
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
// an undefined behavior of bit shifts >= type width).
if(exponent_diff > DstT_mant)
{
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
}
bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
@@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << SrcT_mant);
bool implicit_one = mantissa & (1u << SrcT_mant);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
(act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
bool odd =
mantissa & (1ull << (SrcT_mant -
DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa &
(1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0)
{
if((1ull << SrcT_mant) & mantissa)
if((1u << SrcT_mant) & mantissa)
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
}
else
{
if((1ull << (SrcT_mant + 1)) & mantissa)
if((1u << (SrcT_mant + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
@@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
}
if(f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
mantissa &= (1 << DstT_mant) - 1;
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
}
template <typename SrcT, typename DstT, bool clip = true>
@@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
"SrcT type must be fp8 or bf8.");
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
constexpr bool is_fnuz =
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
@@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
return 0;
}
unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & 0x7F) >> SrcT_mant;
unsigned int sign = x >> (SrcT_exp + SrcT_mant);
unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
if constexpr(is_fnuz)
{
if((x & 0xff) == 0x80)
@@ -530,7 +527,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
}
else
{
if(x == 0x80)
if(x == SrcT(0x80))
{
return fNeg0;
}
@@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
{
retval = x << 8;
retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
return bit_cast<DstT>(retval);
}

View File

@@ -7,6 +7,7 @@
namespace ck_tile {
using index_t = int32_t;
using int32_t = int32_t;
using long_index_t = int64_t;
using int8_t = int8_t;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -19,14 +19,18 @@ struct constant
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <auto v>
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
{
printf("%ld", static_cast<long>(v));
}
template <typename T, T v>
struct integral_constant : constant<v>
{
using value_type = T;
using type = integral_constant; // using injected-class-name
static constexpr T value = v;
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>
@@ -79,4 +83,14 @@ CK_TILE_BINARY_OP(<=)
#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
template <typename T>
struct is_constant : std::false_type
{
};
template <auto v>
struct is_constant<constant<v>> : std::true_type
{
};
template <typename T>
inline constexpr bool is_constant_v = is_constant<T>::value;
} // namespace ck_tile

View File

@@ -31,8 +31,8 @@ struct scales
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
template <typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
-> decltype(std::declval<const Scale&>() * rhs)
CK_TILE_HOST_DEVICE constexpr auto
operator()(const Right& rhs) const -> decltype(std::declval<const Scale&>() * rhs)
{
return lhs_ * rhs;
}
@@ -43,13 +43,13 @@ struct scales
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale>
__host__ __device__ scales(Scale)->scales<Scale>;
__host__ __device__ scales(Scale) -> scales<Scale>;
template <typename Left = void, typename Right = Left>
struct plus
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs + rhs)
{
return lhs + rhs;
}
@@ -59,21 +59,21 @@ template <>
struct plus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs + rhs)
{
return lhs + rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ plus()->plus<void, void>;
__host__ __device__ plus() -> plus<void, void>;
template <typename Left = void, typename Right = Left>
struct minus
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs - rhs)
{
return lhs - rhs;
}
@@ -83,21 +83,21 @@ template <>
struct minus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs - rhs)
{
return lhs - rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ minus()->minus<void, void>;
__host__ __device__ minus() -> minus<void, void>;
template <typename Left = void, typename Right = Left>
struct multiplies
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs * rhs)
{
return lhs * rhs;
}
@@ -107,15 +107,15 @@ template <>
struct multiplies<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ multiplies()->multiplies<void, void>;
__host__ __device__ multiplies() -> multiplies<void, void>;
template <typename T>
struct maximize
@@ -327,8 +327,8 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
template <typename Left = void, typename Right = Left>
struct equal
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs == rhs)
{
return lhs == rhs;
}
@@ -338,15 +338,15 @@ template <>
struct equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs == rhs)
{
return lhs == rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ equal()->equal<void, void>;
__host__ __device__ equal() -> equal<void, void>;
template <>
struct equal<float, float>
@@ -369,8 +369,8 @@ struct equal<double, double>
template <typename Left = void, typename Right = Left>
struct less
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs < rhs)
{
return lhs < rhs;
}
@@ -380,21 +380,21 @@ template <>
struct less<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs < rhs)
{
return lhs < rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less()->less<void, void>;
__host__ __device__ less() -> less<void, void>;
template <typename Left = void, typename Right = Left>
struct less_equal
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
@@ -404,15 +404,15 @@ template <>
struct less_equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less_equal()->less_equal<void, void>;
__host__ __device__ less_equal() -> less_equal<void, void>;
template <>
struct less_equal<float, float>
@@ -487,6 +487,9 @@ struct log2e<float>
template <typename T = double>
constexpr T log2e_v = log2e<T>::value;
template <typename T = double>
constexpr T log2e_rcp_v = 1. / log2e<T>::value;
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
@@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp<double>(double x)
return exp(x);
};
template <typename T>
CK_TILE_DEVICE T tanh_fast(T x)
{
// evaluate exp(2x) once: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
const float e2x = exp<float>(2.0f * type_convert<float>(x));
return type_convert<T>((e2x - 1.0f) / (e2x + 1.0f));
};
template <>
CK_TILE_DEVICE float tanh_fast<float>(float x)
{
// float a = __builtin_amdgcn_sinh(x);
// float b = __builtin_amdgcn_cosh(x);
// float e = a * __builtin_amdgcn_rcpf(b);
// return e;
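// The computation below uses the identity
//   tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) = 1 - 2 / (exp(2x) + 1),
// with exp(2x) evaluated as exp2(2 * log2(e) * x) so the hardware exp2 and
// rcp instructions can be used directly.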
float a = 2.0f * log2e_v<float> * x;
a = __builtin_amdgcn_exp2f(a);
a = __builtin_amdgcn_rcpf(a + 1.0f);
a = 2 * a;
a = 1 - a;
return a;
// float e, r, s, t, d;
// float a = x;
// s = abs(a);
// t = -log2e_v<float> * 2.0f * s;
// e = __builtin_amdgcn_exp2f(t);
// d = e + 1.0f;
// r = __builtin_amdgcn_rcpf(d);
// r = e * (-r) + r;
// if (s < 4.997253418e-3f) r = a;
// union fipnr {float f; unsigned int i;};
// fipnr r_; r_.f = r;
// fipnr a_; a_.f = a;
// { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
// return r;
};
template <typename T>
CK_TILE_DEVICE T log(T x)
{

View File

@@ -0,0 +1,218 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// adapted from include/ck/utility/mxfp_utils.hpp
template <typename T>
struct numeric_utils : numeric_traits<T>
{
using traits = numeric_traits<T>;
using _numeric = numeric<T>;
using raw_type = typename traits::bitwise_type;
static constexpr int exp_mask = (1 << traits::exp) - 1;
static constexpr raw_type get_exponent(raw_type x)
{
// TODO: check if repeated calls are optimized.
return (x >> traits::mant) & exp_mask;
}
static constexpr raw_type get_exponent(const T& x)
{
return get_exponent(bit_cast<raw_type>(x));
}
static constexpr bool is_positive(raw_type x)
{
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
}
static constexpr bool is_subnormal(raw_type x)
{
return get_exponent(x) == _numeric::binary_zero;
}
// TODO: replace double with template arg?
static constexpr double get_mantissa(raw_type x)
{
double mantissa = is_subnormal(x) ? 0.0 : 1.0;
for(raw_type i = 0; i < traits::mant; ++i)
{
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
x >>= 1;
}
return mantissa;
}
};
template <typename T>
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
{
using utils = numeric_utils<T>;
float sign = utils::is_positive(data) ? 1.0 : -1.0;
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
float mant = utils::get_mantissa(data);
return std::ldexp(sign * mant * scale, exp);
}
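// Worked example (assuming a 4-bit e2m1 type with traits exp=2, mant=1, bias=1):
// the raw nibble 0b0111 decodes as sign=+1, biased exponent 0b11=3,
// mantissa 1+0.5=1.5, so the value is ldexp(1.5, 3-1) = 6.0f.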
template <typename T>
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
{
using bitwise_type = typename numeric_traits<T>::bitwise_type;
value /= scale;
if(std::abs(value) > float(numeric<T>::max()))
{
float max_value = numeric<T>::max();
// cppcheck-suppress redundantAssignment
uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
// cppcheck-suppress redundantAssignment
bitwise_type sign =
bit_cast<uint32_t>(value) >> (numeric_traits<float>::exp + numeric_traits<float>::mant);
bitwise_type exp =
((max_bitwise >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask) -
(numeric_traits<float>::bias - numeric_traits<T>::bias);
bitwise_type mantissa =
max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
mant_prev--;
mant_prev <<= (numeric_traits<float>::mant - numeric_traits<T>::mant);
uint32_t prev_bit =
((max_bitwise >> numeric_traits<float>::mant) << numeric_traits<float>::mant) |
mant_prev;
float prev_val = bit_cast<float>(prev_bit);
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
(exp << numeric_traits<T>::mant) | mantissa;
}
else
{
if constexpr(!numeric<T>::has_inf())
{
return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
}
else
{
exp++;
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
(exp << numeric_traits<T>::mant);
}
}
}
const int mfmt = numeric_traits<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & numeric_traits<float>::head_mask;
mantissa = x & numeric_traits<float>::mant_mask;
exponent = (head >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask;
sign = head >> (numeric_traits<float>::mant + numeric_traits<float>::exp);
bias = numeric_traits<float>::bias;
if(x == 0)
{
return 0b0;
}
const int mini_bias = numeric_traits<T>::bias;
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
mantissa += (1UL << mfmt);
}
auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
else
return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
mantissa >>= (mfmt - numeric_traits<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
}
mantissa &= (1UL << numeric_traits<T>::mant) - 1;
return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
(out_exponent << numeric_traits<T>::mant) | mantissa;
}
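// Rounding note (hedged): ties are resolved toward the value whose kept mantissa
// is even (round-to-nearest-even). For e2m1, 5.0f lies exactly midway between the
// representable values 4.0 and 6.0 and is expected to round to 4.0 (even mantissa).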
} // namespace ck_tile

View File

@@ -103,94 +103,92 @@ struct numeric_traits<float>
} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return std::abs(static_cast<float>(x) - static_cast<float>(y)) < \
static_cast<float>(numeric<type_>::epsilon()); \
} \
attr_ bool operator!=(const type_& x, const type_& y) { return not operator==(x, y); } \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}
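// Note on the redefined operator==: two values now compare equal whenever their
// float representations differ by less than the type's epsilon, and operator!=
// is defined as the negation of operator==. For coarse types such as e2m1
// (epsilon = 0.5) this makes equality a tolerance check rather than a bitwise one.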

View File

@@ -0,0 +1,357 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cmath>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
#if defined(__gfx950__)
#define CK_TILE_FP4_CVT_DEVICE 1
#else
#define CK_TILE_FP4_CVT_DEVICE 0
#endif
#define TEST_convert_with_table 0
namespace ck_tile {
using fp32_t = float;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
// TODO: Add stochastic method
struct pk_float4_e2m1_t
{
// TODO: Can we merge raw_type and type?
using raw_type = uint8_t;
using type = raw_type;
raw_type data;
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {}
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
{
}
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
: data{float_to_e2m1(init, scale)}
{
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
template <index_t I>
CK_TILE_HOST_DEVICE constexpr raw_type unpack(number<I>) const;
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
#if TEST_convert_with_table
static constexpr float e2m1_to_fp32_table[16] = {
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
static constexpr fp16_t e2m1_to_fp16_table[16] = {
bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
};
#endif
};
using pk_fp4_t = pk_float4_e2m1_t;
using pk_fp4_raw_t = typename pk_fp4_t::raw_type;
template <>
struct numeric_traits<pk_fp4_t>
{
using bitwise_type = pk_fp4_raw_t;
static constexpr int exp = 2;
static constexpr int mant = 1;
static constexpr int bias = 1;
static constexpr int PackedSize = 2;
};
// limits
template <class T>
struct numeric;
template <>
struct numeric<pk_fp4_t>
{
static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t denorm_min() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
// N/A
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
// N/A
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
// N/A
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
};
template <index_t I>
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 1)
return (data >> 4);
else
return data & 0b00001111;
}
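// Example: pack()/unpack() round-trip of two e2m1 nibbles.
//   pk_fp4_t v = pk_fp4_t::pack(0b0010 /* 1.0 */, 0b0111 /* 6.0 */);
//   v.unpack(number<0>{});  // low nibble  -> 0b0010
//   v.unpack(number<1>{});  // high nibble -> 0b0111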
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, pk_fp4_t)
// TODO: consider replacing this macro to improve performance
#if CK_TILE_FP4_CVT_DEVICE
namespace impl {
template <typename T>
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
{
if constexpr(std::is_same_v<T, fp32_t>)
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, fp32x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
else if constexpr(std::is_same_v<T, fp16_t>)
return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, fp16x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
else if constexpr(std::is_same_v<T, bf16_t>)
return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, bf16x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
else
static_assert(std::false_type::value, "Unsupported type.");
return T{};
}
template <typename T>
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
{
union
{
uint32_t u32;
pk_fp4_raw_t pf4[4];
} cvt{0};
if constexpr(std::is_same_v<T, fp32_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
else if constexpr(std::is_same_v<T, fp32x2_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
else if constexpr(std::is_same_v<T, fp16_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
else if constexpr(std::is_same_v<T, fp16x2_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
else if constexpr(std::is_same_v<T, bf16_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
else if constexpr(std::is_same_v<T, bf16x2_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
else
static_assert(std::false_type::value, "Unsupported type.");
return cvt.pf4[0];
}
} // namespace impl
#endif
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16_t>(data, scale);
#else
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
#endif
}
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16x2_t>(data, scale);
#else
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
#endif
}
// TODO: make float_to_e2m1 generic so that we can convert directly from other types.
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return convert_to_type<pk_fp4_t>(x, scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
{
return float_to_e2m1(x, scale);
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return float_to_e2m1(type_convert<float>(x), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return float_to_e2m1(type_convert<float>(x), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
float_to_e2m1(type_convert<float>(x[1]), scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
float_to_e2m1(type_convert<float>(x[1]), scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
{
return x.to_fp32x2(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
{
return x.to_fp16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
{
return x.to_bf16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
{
return x.to_float(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
{
return x.to_fp16(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
{
return x.to_bf16(scale);
}
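// Usage sketch (hedged; values and scale are illustrative): quantize a pair of
// floats to packed fp4 and dequantize them again. Both inputs are exactly
// representable in e2m1, so the round trip is lossless here.
//   fp32x2_t in{1.0f, -3.0f};
//   pk_fp4_t q   = fp32x2_to_pk_fp4(in, /*scale=*/1.0f);
//   fp32x2_t out = pk_fp4_to_fp32x2(q, /*scale=*/1.0f); // {1.0f, -3.0f}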
#if TEST_convert_with_table == 0
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32_t>(data, scale);
#else
return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32x2_t>(data, scale);
#else
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
#endif
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16_t>(data, scale);
#else
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
#endif
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16x2_t>(data, scale);
#else
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
#endif
}
#else
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale,
                e2m1_to_fp32_table[unpack(number<1>{})] * scale};
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
return type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) *
                            scale);
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
return fp16x2_t{
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
}
#endif
} // namespace ck_tile

View File

@@ -99,7 +99,7 @@ struct numeric_traits<pk_int4_t>
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
@@ -116,6 +116,24 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
return res;
}
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0);
float x_h = ((x_u8 & 0xf0) >> 4);
x_l = x_l > 7 ? x_l - 16 : x_l;
x_h = x_h > 7 ? x_h - 16 : x_h;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
fp32x2_t res = {x_h, x_l};
#else
fp32x2_t res = {x_l, x_h};
#endif
return res;
}
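// Example: the low nibble 0xF decodes to 15, which is > 7, so it is
// sign-extended to 15 - 16 = -1; the nibble 0x7 stays +7.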
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
namespace ck_tile {
@@ -63,8 +64,44 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
#undef CK_TILE_TYPE_CONVERT
} // namespace ck_tile
#include "ck_tile/core/numeric/pk_fp4.hpp"
namespace ck_tile {
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
float scale) \
{ \
return sname_##_to_##dname_(x, scale); \
} \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return sname_##_to_##dname_(x, 1.f); \
}
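// For example, CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
// expands to a scaled_type_convert<pk_fp4_t, float> specialization that forwards
// to float_to_pk_fp4(x, scale), and a type_convert<pk_fp4_t, float>
// specialization that forwards to float_to_pk_fp4(x, 1.f).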
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4)
#undef CK_TILE_SCALED_TYPE_CONVERT
#endif
} // namespace ck_tile

View File

@@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T, typename>
template <typename T, typename = void>
struct vector_traits
{
using scalar_type =
@@ -94,7 +94,7 @@ struct vector_traits
// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;
@@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...

View File

@@ -5,7 +5,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#if __clang_major__ == 20
#if __clang_major__ >= 20
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#else
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
@@ -18,6 +18,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/ignore.hpp"
namespace ck_tile {
@@ -133,6 +134,28 @@ struct buffer_view<address_space_enum::generic,
}
}
/*
The transpose instruction is not supported by buffer views in the generic address space.
Attempting to use it triggers a compile-time error.
*/
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
static_assert(false, "Error: transpose load not supported in global memory space.");
ignore = i;
ignore = linear_offset;
ignore = is_valid_element;
return;
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
@@ -187,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: generic, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Global
@@ -359,6 +360,28 @@ struct buffer_view<address_space_enum::global,
}
}
/*
The transpose instruction is not supported by buffer views in the global memory address
space. Attempting to use it triggers a compile-time error.
*/
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
static_assert(false, "Error: transpose load not supported in global memory space.");
ignore = i;
ignore = linear_offset;
ignore = is_valid_element;
return;
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
@@ -407,10 +430,12 @@ struct buffer_view<address_space_enum::global,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem,
cached_buf_res_,
src_wave_buffer_resource,
i,
linear_offset,
is_valid_element,
@@ -710,28 +735,6 @@ struct buffer_view<address_space_enum::global,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Global, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: LDS
@@ -852,6 +855,47 @@ struct buffer_view<address_space_enum::lds,
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
}
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
[[maybe_unused]] index_t linear_offset,
bool is_valid_element) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
if(is_valid_element)
{
#if defined(__gfx950__)
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr address_space_enum addr_space = get_address_space();
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
p_data_ + i + linear_offset);
#else
return X{numeric<remove_cvref_t<T>>::zero()};
#endif
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
return X{invalid_element_value_};
}
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
@@ -906,45 +950,39 @@ struct buffer_view<address_space_enum::lds,
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
// clang-format off
static_assert(
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// int8 on thread buffer
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
"wrong! not implemented for this combination, please add "
"implementation");
// clang-format on
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
{
@@ -955,6 +993,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
{
@@ -965,6 +1005,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
{
@@ -975,6 +1017,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
{
@@ -985,6 +1029,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
{
@@ -1048,28 +1094,6 @@ struct buffer_view<address_space_enum::lds,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Lds, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Vgpr
@@ -1223,28 +1247,6 @@ struct buffer_view<address_space_enum::vgpr,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Vgpr, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
template <address_space_enum BufferAddressSpace,
@@ -1270,4 +1272,25 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
p, buffer_size, invalid_element_value};
}
// Generalized print function for all buffer_view variants
template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence>
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
Coherence>& bv)
{
printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
address_space_to_string(BufferAddressSpace),
static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
print(bv.buffer_size_);
printf(", invalid_element_value_: ");
print(bv.invalid_element_value_);
printf("}");
}
} // namespace ck_tile

View File

@@ -18,32 +18,8 @@
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
template <typename TileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -51,35 +27,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -138,42 +90,25 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})

View File

@@ -0,0 +1,446 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
constexpr int DS_READ_TR_SIZE()
{
return 8; // Literal constant, evaluated at compile time
}
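// Note (assumption): 8 bytes presumably matches the per-lane granularity of the
// 64-bit LDS transpose reads (ds_read_b64_tr_*) introduced on gfx950.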
namespace util {
template <typename Suffix, typename Sequence>
struct is_sequence_suffix
{
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
static constexpr bool value =
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
};
template <index_t... Xs>
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
{
static constexpr bool value = true;
};
template <typename Suffix, typename Sequence>
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
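// Example:
//   static_assert(is_sequence_suffix_v<sequence<4, 4>, sequence<2, 4, 4>>);   // true
//   static_assert(!is_sequence_suffix_v<sequence<2, 4>, sequence<2, 4, 4>>);  // not a suffix
//   static_assert(is_sequence_suffix_v<sequence<>, sequence<2, 4, 4>>);       // empty suffix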
} // namespace util
// Default policy: Retains original 2D transpose behavior
template <typename DataType>
struct DefaultTranspose
{
template <index_t LaneGroupSize>
struct Quad16
{
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
template <index_t LaneGroupSize>
struct Quad8
{
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
// Select based on data size
template <index_t LaneGroupSize>
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16<LaneGroupSize>::InputEncoding,
typename Quad8<LaneGroupSize>::InputEncoding>;
template <index_t LaneGroupSize>
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16<LaneGroupSize>::OutputEncoding,
typename Quad8<LaneGroupSize>::OutputEncoding>;
// Always swap last two dimensions
static constexpr auto transpose_dims = sequence<1, 0>{};
// Programmable: Element grouping function
static constexpr auto group_func = [](auto idx) {
return idx; // Identity mapping
};
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
struct ValidationTraitsImpl
{
using QuadEncoding = std::conditional_t<ReverseDirection,
QuadOutputEncoding<LaneGroupSize>,
QuadInputEncoding<LaneGroupSize>>;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
static constexpr auto quad_hs = QuadEncoding::hs_lengthss_;
// 1. Must be 2D tensor
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
// 2. Quad pattern must be suffix of input pattern
static constexpr bool suffix_valid_dim0 =
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
static constexpr bool suffix_valid_dim1 =
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
// 3. PS→RHS mapping constraints
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0];
static constexpr auto input_ps_major_last =
input_ps_major[number<input_ps_major.size() - 1>{}];
static constexpr auto input_ps_minor_last =
input_ps_minor[number<input_ps_minor.size() - 1>{}];
using psys_offset = ck_tile::sequence<input_hs[I0].size() - quad_hs[I0].size(),
input_hs[I1].size() - quad_hs[I1].size()>;
static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2(
[](auto i) {
return number<quad_ps_minor0[i] + psys_offset{}[quad_ps_major0[i] - 1]>{};
},
number<quad_ps_minor0.size()>{});
static constexpr bool ps_mapping_valid =
util::is_sequence_suffix_v<decltype(quad_ps_major0), decltype(input_ps_major_last)> &&
util::is_sequence_suffix_v<decltype(shifted_quad_ps_minor0),
decltype(input_ps_minor_last)>;
// 4. YS→RHS mapping constraints
static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_;
static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_;
static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1,
"YS->RHS mapping must be single dimension");
static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1,
"YS->RHS mapping must be the last dimension");
static constexpr bool ys_mapping_valid =
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
ps_mapping_valid && ys_mapping_valid;
};
template <typename InDstrEncode, bool ReverseDirection = false>
struct ValidationTraits
{
static constexpr bool value =
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value;
static constexpr index_t LaneGroupSize =
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ? 64
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ? 32
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value ? 16
: 0;
};
};
template <typename TileDistribution_, typename DataType_, typename Policy>
struct TransposeTileDistrChecker
{
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
static constexpr bool distr_encoding_valid = Validator::value;
};
// this is used to generate the transposed output tile distribution encoding
// based on the input tile distribution encoding
template <typename TileDistributionEncoding_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>,
bool ReverseDirection = false>
struct TransposeTileDistributionTraits
{
using InDstrEncode = remove_cvref_t<TileDistributionEncoding_>;
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr index_t LaneGroupSize =
Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::LaneGroupSize;
static_assert(Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::value,
"The input tile distribution encoding is not valid for transpose!");
using QuadInputEncoding = std::conditional_t< //
ReverseDirection,
typename Policy::template QuadOutputEncoding<LaneGroupSize>,
typename Policy::template QuadInputEncoding<LaneGroupSize>>;
using QuadOutputEncoding = std::conditional_t< //
ReverseDirection,
typename Policy::template QuadInputEncoding<LaneGroupSize>,
typename Policy::template QuadOutputEncoding<LaneGroupSize>>;
static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_;
static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_;
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto I0 = number<0>{};
static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0];
static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0];
static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_;
static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_;
static constexpr index_t dim0 = Policy::transpose_dims[0];
static constexpr index_t dim1 = Policy::transpose_dims[1];
static constexpr auto swap_one_and_two = [](const index_t idx) {
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
};
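// swap_one_and_two remaps RHS major indices when the two X dimensions are
// reversed: index 0 (the R dimension) is unchanged, while 1 <-> 2 (the H
// dimensions for dim0/dim1) swap, e.g. {0, 1, 2} -> {0, 2, 1}.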
// for transpose load
// strip the quad_input_hs_lengthss suffix from input_hs_lengthss in each dimension, reverse the
// dimensions, then append the quad_output_hs_lengthss to the end of each dimension
static constexpr auto outer_hs_lengthss = generate_tuple(
[](auto i) {
constexpr auto input_i = input_hs_lengthss[i];
constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size();
return typename sequence_split<decltype(input_i), outer_len>::left_type{};
},
number<InDstrEncode::NDimX>{});
static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss);
static constexpr auto dst_out_hs_lengthss = generate_tuple(
[](auto i) {
auto outer_i = reversed_outer_hs_lengthss[i];
// append the reversed quad output hs lengths to the outer hs lengths
return outer_i.push_back(quad_output_hs_lengthss[i]);
},
number<InDstrEncode::NDimX>{});
// for the PS→RHS mapping (both major and minor), we need to modify the last element of the
// major sequence (which is for the thread distribution)
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
// because dst_out_hs_lengthss is reversed, this major index also needs to be reversed
[](auto i) {
if constexpr(i == input_ps_to_rhss_major.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_major[i].size();
constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size();
constexpr auto quad_out = quad_output_ps_to_rhss_major0;
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out);
}
else
{
// For all other sequences (i.e. warp), keep them unchanged
return input_ps_to_rhss_major[i].transform(swap_one_and_two);
}
},
number<input_ps_to_rhss_major.size()>{});
static constexpr auto quad_idx_offset =
transform_tuples([](auto x) { return number<x.size()>{}; }, reversed_outer_hs_lengthss);
// minus 1 because RsLength is not counted
static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for(
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ps_to_rhss_major0));
static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for(
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ys_to_rhs_major));
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
[](auto i) {
constexpr auto input_i = input_ps_to_rhss_minor[i];
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
{
constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size();
constexpr auto outer_ps =
typename sequence_split<decltype(input_i), outer_len>::left_type{};
return outer_ps.push_back(quad_output_ps_minor_offset +
quad_output_ps_to_rhss_minor0);
}
else
{
// For all other sequences, keep them unchanged
return input_i;
}
},
number<input_ps_to_rhss_minor.size()>{});
static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back();
// because dst_out_hs_lengthss is reversed, this major index also needs to be reversed
static constexpr auto dst_ys_to_rhs_major =
outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{});
static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back(
number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{});
using TransposedDstrEncode =
tile_distribution_encoding<typename InDstrEncode::RsLengths,
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
};
template <typename TileDistributionEncoding_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
using OutputTileDistributionTraits =
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, false>;
template <typename TileDistributionEncoding_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
using InputTileDistributionTraits =
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;
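// A minimal usage sketch (hypothetical encoding type `Encode`; a sketch, not part of
// this header):
//   using OutEncode =
//       typename OutputTileDistributionTraits<Encode, half_t>::TransposedDstrEncode;
//   constexpr auto out_dstr = make_static_tile_distribution(OutEncode{});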
template <typename InnerEncode,
index_t kLeadIterPerWarp,
index_t kSecondIterPerWarp,
index_t kLeadNumWarps,
index_t kSecondNumWarps>
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
{
constexpr auto block_outer_dst_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto blk_distr_encode =
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
return blk_distr_encode;
}
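// A minimal sketch of intended use (hypothetical template arguments):
//   constexpr auto encode =
//       InputTileDistributionEncoding<QuadEncode, /*kLeadIterPerWarp*/ 2,
//                                     /*kSecondIterPerWarp*/ 2, /*kLeadNumWarps*/ 4,
//                                     /*kSecondNumWarps*/ 1>();
// this embeds QuadEncode into the per-warp/per-iteration outer block encoding above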
/**
 * @brief Transpose-loads a tile from a tensor and returns the resulting tensor with a new
 * (transposed) tile distribution. SFINAE ensures the tile distribution encoding is valid.
*
* This function is intended for use with statically distributed tensor tiles, where the input
* and output tile distributions differ due to the transpose operation. It ensures that the
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
* @tparam NumCoord The number of coordinates (dimensions).
 * @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
 * The last (unnamed) template parameter is SFINAE, ensuring the tile distribution encoding is
 * valid.
*
* @param tile_window The tile window with static distribution to load and transpose.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose<Policy>();
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
static_assert(y_in_element_space_size == y_out_element_space_size,
"the element space size is not the same!");
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
"the vector length is not the same!");
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
constexpr index_t num_of_access =
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
out_tensor.get_thread_buffer().template set_as<DataVec>(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});
return out_tensor;
}
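// A minimal usage sketch (assumed identifiers `view`, `window_lengths`, `origin` and
// `dstr`; a sketch under these assumptions, not a definitive example):
//   auto window     = make_tile_window(view, window_lengths, origin, dstr);
//   auto transposed = load_tile_transpose(window); // tile with transposed distribution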
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -53,10 +53,13 @@ struct is_null_tile_window<null_tile_window<T>> : public std::true_type
};
} // namespace impl
template <typename T>
constexpr bool is_null_tile_window_v = impl::is_null_tile_window<remove_cvref_t<T>>::value;
template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
return is_null_tile_window_v<remove_cvref_t<T>>;
}
template <typename WindowLengths>

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
[&](auto ii) {
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
: static_cast<index_t>(idx_y_start[ii]);
},
number<NDimY>{});
constexpr auto idx_y_out =

View File

@@ -303,6 +303,6 @@ struct tile_sweeper
template <typename T,
typename F,
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -81,7 +81,7 @@ struct tensor_adaptor
template <index_t IDimHidden>
CK_TILE_HOST_DEVICE static constexpr auto
get_transform_and_its_upper_dimension(number<IDimHidden>)
get_transform_and_its_upper_dimension(number<IDimHidden>)
{
// FIXME: length of bottom dimension is not known, since info about lower dim lengths is not
// saved in the transformation
@@ -119,13 +119,13 @@ struct tensor_adaptor
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_low_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_up_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -259,6 +259,7 @@ struct tensor_adaptor
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
template <index_t Internal = 0>
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
@@ -266,7 +267,9 @@ struct tensor_adaptor
auto vector_lengths = guaranteed_vector_lengths;
auto vector_strides = guaranteed_vector_strides;
static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
static_for<0,
Internal ? std::min(Internal, get_num_of_transform()) : get_num_of_transform(),
1>{}([&](auto itran) {
constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
@@ -298,42 +301,16 @@ struct tensor_adaptor
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
set_container_subset(vector_strides, up_dims, up_vector_strides);
});
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_adaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
if constexpr(Internal > 0)
{
return make_tuple(vector_lengths, vector_strides);
}
else
{
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
}
private:
@@ -341,6 +318,40 @@ struct tensor_adaptor
ElementSize element_size_;
};
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename BottomDimensionHiddenIds,
typename TopDimensionHiddenIds>
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
BottomDimensionHiddenIds,
TopDimensionHiddenIds>& adaptor)
{
printf("tensor_adaptor{\n");
printf(" transforms: [");
print(adaptor.get_transforms());
printf("],\n");
printf(" LowerDimensionHiddenIds: [");
print(LowerDimensionHiddenIdss{});
printf("],\n");
printf(" UpperDimensionHiddenIds: [");
print(UpperDimensionHiddenIdss{});
printf("],\n");
printf(" BottomDimensionHiddenIds: [");
print(BottomDimensionHiddenIds{});
printf("],\n");
//
printf(" TopDimensionHiddenIds: [");
print(TopDimensionHiddenIds{});
printf("]\n}\n");
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
@@ -461,7 +472,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
@@ -470,8 +481,8 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
number<num_new_transform>{});
// new top dimension's hidden ids
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto unordered_new_top_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_top_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
@@ -595,8 +606,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
// sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique
@@ -619,8 +629,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
});
return low_dim_hidden_ids_1_mod_;
}
();
}();
return generate_sequence_v2(
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
@@ -643,8 +652,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
// sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id
@@ -653,8 +661,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
});
return up_dim_hidden_ids_1_mod_;
}
();
}();
// constexpr tuple to sequence
return generate_sequence_v2(

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -133,32 +133,45 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
template <index_t Internal = 0>
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
{
return Base::get_top_dimension_safe_vector_length_strides(
return Base::template get_top_dimension_safe_vector_length_strides<Internal>(
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_descriptor{");
// tensor_adaptor
Base::print();
printf(", ");
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
printf("}");
}
// TODO make these private
ElementSpaceSize element_space_size_;
};
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename TopDimensionHiddenIds,
typename ElementSpaceSize,
typename GuaranteedVectorLengths,
typename GuaranteedVectorStrides>
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
TopDimensionHiddenIds,
ElementSpaceSize,
GuaranteedVectorLengths,
GuaranteedVectorStrides>& descriptor)
{
printf("tensor_descriptor{\n");
// first print the tensor adaptor part of the descriptor using the base class print
print(static_cast<const typename remove_cvref_t<decltype(descriptor)>::Base&>(descriptor));
printf("element_space_size_: %ld,\n",
static_cast<long>(descriptor.get_element_space_size().value));
printf("guaranteed_vector_lengths: ");
print(GuaranteedVectorLengths{});
printf(",\nguaranteed_vector_strides: ");
print(GuaranteedVectorStrides{});
printf("}\n}\n");
}
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
@@ -365,12 +378,29 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
constexpr index_t first_dim_length = []() {
if constexpr(is_constant_v<remove_cvref_t<decltype(element_space_size)>>)
return decltype(element_space_size)::value;
else
return -1;
}();
using last_t = remove_cvref_t<decltype(lengths.template get<N - 1>())>;
constexpr index_t last_dim_length = []() {
if constexpr(is_constant_v<last_t>)
return std::max(last_t::value, GuaranteedLastDimensionVectorLength);
else
return -1;
}();
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
typename sequence_merge<sequence<first_dim_length>,
typename uniform_sequence_gen<N - 1, -1>::type,
sequence<last_dim_length>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;
typename sequence_merge<sequence<1>,
typename uniform_sequence_gen<N - 1, -1>::type,
sequence<1>>::type;
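// e.g. for compile-time lengths (2, 3, 4) and GuaranteedLastDimensionVectorLength G,
// the merged sequences above become
//   GuaranteedVectorLengths = sequence<24, -1, -1, max(4, G)>
//   GuaranteedVectorStrides = sequence<1, -1, -1, 1>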
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,

View File

@@ -161,7 +161,8 @@ struct tensor_view
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset) const
index_t linear_offset,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(
smem,
@@ -181,7 +182,8 @@ struct tensor_view
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element) const
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(smem,
coord.get_offset() / PackedSize,
@@ -210,6 +212,27 @@ struct tensor_view
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t coord_extra_offset,
index_t linear_offset,
bool_constant<pre_nop> = {}) const
{
return buf_.template async_get_raw<X>(
smem,
(coord.get_offset() + coord_extra_offset) / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
@@ -230,6 +253,33 @@ struct tensor_view
bool_constant<pre_nop>{});
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
{
return buf_.template transpose_get<X>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_transpose_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element // flag
) const
{
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
@@ -384,22 +434,6 @@ struct tensor_view
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
// buf_
printf("buf_: ");
print(buf_);
printf(", ");
// desc_
printf("desc_: ");
print(desc_);
printf("}");
}
// member
buffer_view buf_;
TensorDesc desc_;
@@ -411,18 +445,21 @@ struct null_tensor_view
};
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename DataType,
typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
const tensor_descriptor<Ts...>& desc)
{
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
auto buffer_view =
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
memory_operation_enum DstInMemOp = memory_operation_enum::set,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename DataType,
typename... Lengths,
typename... Strides,
@@ -441,12 +478,14 @@ make_naive_tensor_view(DataType* p,
number<GuaranteedLastDimensionVectorLength>{},
number<GuaranteedLastDimensionVectorStride>{});
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
auto buffer_view =
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename DataType,
typename... Lengths,
index_t GuaranteedLastDimensionVectorLength = -1>
@@ -458,7 +497,8 @@ make_naive_tensor_view_packed(DataType* p,
auto desc =
make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
auto buffer_view =
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
@@ -488,6 +528,7 @@ template <typename TensorView,
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
constexpr index_t num_dim = DoPads::size();
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),

View File

@@ -204,7 +204,7 @@ struct tile_distribution
// FIXME: it's hacky to get Y index from Distributed-Index
template <typename DistributedIndices>
CK_TILE_HOST_DEVICE static constexpr auto
get_y_indices_from_distributed_indices(DistributedIndices)
get_y_indices_from_distributed_indices(DistributedIndices)
{
constexpr auto ys_idx_arr = [] {
array<index_t, NDimY> ys_idx;
@@ -230,24 +230,6 @@ struct tile_distribution
{
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution{");
//
printf("tile_distribution_encoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
@@ -268,7 +250,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t
// this returns a constexpr encoding of tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
{
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
@@ -544,26 +526,26 @@ namespace detail {
//
// e.g
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
// totally 16 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slice along this P dim, will split threads, not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y sim need to split into 2
@@ -579,27 +561,55 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
constexpr auto x_slice_ends_ = generate_sequence_v2(
[&](auto i) {
if constexpr(x_slice_ends[i] == -1)
{
// -1 means till the end
constexpr auto x_length_ =
container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
return x_length_;
}
else
{
return x_slice_ends[i];
}
},
number<x_slice_ends.size()>{});
constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
[&](auto i) constexpr {
constexpr auto len_ = x_slice_lengths[i];
static_assert(len_ % p_len_over_h[i] == 0,
"slice length must be dividable by p_len_over_h");
return number<len_ / p_len_over_h[i]>{};
},
number<x_slice_lengths.size()>{});
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
constexpr auto src_y_dims = src_y_info[number<0>{}];
constexpr auto src_y_maps = src_y_info[number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
{
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
auto y_slice_lengths = Encoding::detail::ys_lengths_;
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
// This lambda modifies values outside its own scope, so C++ will not treat its return
// value as constexpr
// TODO: ugly
auto new_h_lengths = transform_tuples(
[&](auto h_len, auto id) {
constexpr auto sliced_h =
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
constexpr auto sliced_h = reverse_slice_sequence(
h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
constexpr auto sliced_h_index = sliced_h[number<2>{}];
@@ -607,26 +617,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
// update y_slice_lengths
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
"not sliced at y dim, please check");
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[found_y_index - i]) =
sliced_h_lens[sliced_h_index - i];
});
{
constexpr auto sliced_y_to_h_lens =
pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
static_for<0, sliced_y_to_h_dims, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
});
}
// TODO: add validation that the slice does not cross a p dim
// NOTE: this y_origin is for all dims, not only the current dim;
// we will later use pick to select the target dim
constexpr auto y_origin = [&]() {
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
// can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
constexpr auto y_to_h_len =
pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
constexpr auto y_to_h_dims = y_to_h_len.size();
constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
static_for<0, y_to_h_dims, 1>{}([&](auto i) {
y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
});
return y_origin_;
}();
@@ -645,8 +668,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
}
();
}();
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
@@ -672,4 +694,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
}
} // namespace detail
// Free print function for tile_distribution
template <typename PsYs2XsAdaptor_,
typename Ys2DDescriptor_,
typename StaticTileDistributionEncoding_,
typename TileDistributionDetail_>
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
Ys2DDescriptor_,
StaticTileDistributionEncoding_,
TileDistributionDetail_>& distribution)
{
printf("tile_distribution{");
printf("tile_distribution_encoding: ");
print(StaticTileDistributionEncoding_{});
printf(", ");
printf("ps_ys_to_xs_: ");
print(distribution.ps_ys_to_xs_);
printf(", ");
printf("ys_to_d_: ");
print(distribution.ys_to_d_);
printf("}\n");
}
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -47,6 +47,11 @@ struct tile_distribution_encoding
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
#if !CK_TILE_ENC_SUPPORT_Y_TO_R
static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
"do not support Y dim pointed to R dim");
#endif
// redundant but useful info
// TODO: really bad code, should be over-hauled
struct detail
@@ -255,33 +260,107 @@ struct tile_distribution_encoding
}
}();
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_h_dim_lengths()
{
// <len_d0, len_d1, ...>
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
[&](auto i) {
constexpr index_t size = HsLengthss{}[i].size();
return number<size>{};
constexpr index_t size_ = HsLengthss{}[i].size();
return number<size_>{};
},
number<NDimX>{});
return uniformed_h_dim_lengths;
}
// note: this function only counts the p dim length along h, not r
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_p_dim_lengths_over_h()
{
// e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
// Y P Y Y P Y P Y
// | | |
// v v v
// return : seq<4, 2 * 4> => seq<4, 8>
constexpr auto uniformed_ps_to_rhss_major_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
constexpr auto uniformed_ps_to_rhss_minor_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
constexpr auto p_len_ = [&]() {
array<index_t, NDimX> len_{1};
static_for<0, NDimX, 1>{}([&](auto idim_x_) {
constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
{
constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
len_[idim_x_] *= h_length_;
}
});
});
return len_;
}();
constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
return p_len_over_h_seq_;
}
//
// R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
// => return seq<1, 3, 5>
// R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
// => return seq<0, 2, 3>
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_rh_dim_lengths()
{
constexpr auto uniformed_rh_dim_lengths =
merge_sequences(sequence<NDimR>{} /*for R dims*/, get_uniformed_h_dim_lengths());
return uniformed_rh_dim_lengths;
}
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
{
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
return h_dim_prefix_sum;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
CK_TILE_HOST_DEVICE static constexpr auto get_rh_dim_lengths_prefix_sum()
{
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
return rh_dim_prefix_sum;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_p_to_h()
{
// tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
constexpr auto uniformed_ps_to_rhss_major_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
constexpr auto uniformed_ps_to_rhss_minor_ =
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
constexpr auto all_ps_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
return rh_dim_prefix_sum.at(major) + minor;
},
uniformed_ps_to_rhss_major_,
uniformed_ps_to_rhss_minor_);
return all_ps_2_rhss;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_rh()
{
constexpr auto all_ys_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
// <0, 0, len_d0, len_d0+len_d1, ...>
constexpr auto x_dim_prefix_sum = merge_sequences(
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
return x_dim_prefix_sum.at(major) + minor;
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
return rh_dim_prefix_sum.at(major) + minor;
},
Ys2RHsMajor{},
Ys2RHsMinor{});
@@ -289,6 +368,45 @@ struct tile_distribution_encoding
return all_ys_2_rhss;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
{
// TODO: Y can't point to R
constexpr auto all_ys_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
return rh_dim_prefix_sum.at(major) + minor - NDimR;
},
Ys2RHsMajor{},
Ys2RHsMinor{});
return all_ys_2_rhss;
}
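// e.g. R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> gives rh prefix sum
// seq<0, 1, 4, 9>, so a Y at (major 2, minor 3) maps to uniformed h index
// 4 + 3 - NDimR = 6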
// return tuple of seq
CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
{
constexpr auto masks_ = generate_tuple(
[&](auto i) {
constexpr auto size_ = HsLengthss{}[i].size();
constexpr auto current_y_to_h_mask_ = [&]() {
array<index_t, size_> m_{0};
// TODO: we loop over all y for each h dim
for(auto j = 0; j < NDimY; j++)
{
if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
{
m_[Ys2RHsMinor{}[j]] = 1;
}
}
return m_;
}();
return TO_SEQUENCE(current_y_to_h_mask_, size_);
},
number<NDimX>{});
return masks_;
}
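// e.g. for H: tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>> with dims labeled
//      Y P Y / Y P Y P Y (as in the example further above), the masks are
//      tuple<seq<1, 0, 1>, seq<1, 0, 1, 0, 1>>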
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
template <typename IdxSeq, typename PrefixSumSeq>
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
@@ -305,115 +423,34 @@ struct tile_distribution_encoding
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
}
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
// Note here y_to_h does not count R dim!
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_to_h_info()
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding::detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("detail: ");
print(detail{});
//
printf("}");
}
};
template <typename encoding, typename shuffle>
class tile_distribution_encoding_shuffle;
template <typename encoding, index_t... shuffle>
class tile_distribution_encoding_shuffle<encoding, sequence<shuffle...>>
{
template <typename Ys2RHs>
using shuffled = sequence<(Ys2RHs::template get<shuffle>())...>;
public:
using type = tile_distribution_encoding<typename encoding::RsLengths,
typename encoding::HsLengthss,
typename encoding::Ps2RHssMajor,
typename encoding::Ps2RHssMinor,
shuffled<typename encoding::Ys2RHsMajor>,
shuffled<typename encoding::Ys2RHsMinor>>;
};
template <typename encoding, typename shuffle>
using tile_distribution_encoding_shuffle_t =
typename tile_distribution_encoding_shuffle<encoding, shuffle>::type;
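// e.g. tile_distribution_encoding_shuffle_t<Enc, sequence<1, 0>> reorders the Y dims of a
// two-Y-dim `Enc` as (y1, y0); `Enc` here is a hypothetical encoding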
namespace detail {
template <typename OuterDstr, typename InnerDstr>
@@ -757,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
}
} // namespace detail
// Free print function for tile_distribution_encoding::detail
template <typename RsLengths_,
typename HsLengthss_,
typename Ps2RHssMajor_,
typename Ps2RHssMinor_,
typename Ys2RHsMajor_,
typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void
print(const typename tile_distribution_encoding<RsLengths_,
HsLengthss_,
Ps2RHssMajor_,
Ps2RHssMinor_,
Ys2RHsMajor_,
Ys2RHsMinor_>::detail& detail_obj)
{
printf("tile_distribution_encoding::detail{");
printf("ndim_rh_major_: ");
print(detail_obj.ndim_rh_major_);
printf(", ");
printf("ndim_span_major_: ");
print(detail_obj.ndim_span_major_);
printf(", ");
printf("ndims_rhs_minor_: ");
print(detail_obj.ndims_rhs_minor_);
printf(", ");
printf("ndim_rh_major_: ");
print(detail_obj.ndim_rh_major_);
printf(", ");
printf("max_ndim_rh_minor_: ");
print(detail_obj.max_ndim_rh_minor_);
printf(", ");
printf("rhs_lengthss_: ");
print(detail_obj.rhs_lengthss_);
printf(", ");
printf("ys_lengths_: ");
print(detail_obj.ys_lengths_);
printf(", ");
printf("rhs_major_minor_to_ys_: ");
print(detail_obj.rhs_major_minor_to_ys_);
printf(", ");
printf("ndims_span_minor_: ");
print(detail_obj.ndims_span_minor_);
printf(", ");
printf("max_ndim_span_minor_: ");
print(detail_obj.max_ndim_span_minor_);
printf(", ");
printf("ys_to_span_major_: ");
print(detail_obj.ys_to_span_major_);
printf(", ");
printf("ys_to_span_minor_: ");
print(detail_obj.ys_to_span_minor_);
printf(", ");
printf("distributed_spans_lengthss_: ");
print(detail_obj.distributed_spans_lengthss_);
printf(", ");
printf("ndims_distributed_spans_minor_: ");
print(detail_obj.ndims_distributed_spans_minor_);
printf(", ");
printf("ps_over_rs_derivative_: ");
print(detail_obj.ps_over_rs_derivative_);
printf("}");
}
// Free print function for tile_distribution_encoding
template <typename RsLengths_,
typename HsLengthss_,
typename Ps2RHssMajor_,
typename Ps2RHssMinor_,
typename Ys2RHsMajor_,
typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
HsLengthss_,
Ps2RHssMajor_,
Ps2RHssMinor_,
Ys2RHsMajor_,
Ys2RHsMinor_>& encoding)
{
printf("tile_distribution_encoding{");
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
printf("rs_lengths_: ");
print(encoding.rs_lengths_);
printf(", ");
printf("hs_lengthss_: ");
print(encoding.hs_lengthss_);
printf(", ");
printf("ps_to_rhss_major_: ");
print(encoding.ps_to_rhss_major_);
printf(", ");
printf("ps_to_rhss_minor_: ");
print(encoding.ps_to_rhss_minor_);
printf(", ");
printf("ys_to_rhs_major_: ");
print(encoding.ys_to_rhs_major_);
printf(", ");
printf("ys_to_rhs_minor_: ");
print(encoding.ys_to_rhs_minor_);
printf(", ");
printf("}");
}
} // namespace ck_tile

View File

@@ -59,6 +59,38 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
return out_dstr_tensor;
}
/**
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
*
* @param in_element_func Function to apply element-wise.
* @param t Any container containing elements to process, with known size and
 * tuple-like semantics.
* @return Calls tile_elementwise_inout with unpacked tuple elements.
*/
template <typename InElementFunc, typename Tuple, size_t... I>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
const Tuple& t,
std::index_sequence<I...>)
{
return tile_elementwise_inout(in_element_func, t[number<I>{}]...);
}
/**
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
*
* @param in_element_func Function to apply element-wise.
* @param t Any container containing elements to process, with known size and
 * tuple-like semantics.
* @return Calls the overloaded function, passing an index sequence.
*/
template <typename InElementFunc, typename Tuple>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
const Tuple& t)
{
static constexpr auto size = Tuple::size();
return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
}
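// A minimal usage sketch (assuming `tensors` is a ck_tile::tuple of distributed tensors
// with matching distributions; a sketch, not a definitive example):
//   tile_elementwise_inout_unpack(
//       [](auto& out, const auto& in) { out = in; }, tensors);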
template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
@@ -295,9 +327,8 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
if constexpr((std::is_same_v<DstType, fp8_t> ||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
float> &&
if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 4 == 0))
{
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);

View File

@@ -0,0 +1,858 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/**
* @brief This class provides tile (windowed) view and access to the device memory.
*
 * @note This tile window does not support single issue; use the tile_window_linear
 * structure for that purpose
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
struct tile_scatter_gather
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
using ValidArray = remove_cvref_t<StaticValidArray_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct load_store_traits
{
private:
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
tile_scatter_gather::get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
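// scan the Y dims for a unit-stride dim with the longest safe vector length;
// that dim becomes VectorDimY and its length becomes ScalarPerVector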
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_)>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
};
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx,
const ValidArray& valids)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
valids_{valids},
pre_computed_coords_{}
{
#if 0 // debug
// TODO: this uses more registers for FA, but fewer registers for GEMM;
// needs investigation
// only supports warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this uses fewer registers for FA, but more registers for GEMM;
// needs investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
const vector_t vec_value = [&]() {
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
}
else
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
bool_constant<oob_conditional_check>{});
}
}();
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
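// Usage sketch (hypothetical names): for a scatter/gather window `w` whose rows along
// YsGatherDim are indirected through page_idx_, a plain load gathers through the page table:
//   auto tile = w.load(); // returns a static_distributed_tensor<DataType, TileDstr>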
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
using LdsDataType = typename LdsTileWindow::DataType;
// using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
}
else
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem,
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
0,
pre_nop_);
}
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
m0_inc_with_memory(size_per_issue);
}
});
});
}
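// Note on the m0 bookkeeping above (restating the offset math, not new behavior):
// for an LDS view shaped (issues, warps, lanes), size_per_issue is the byte distance
// between consecutive issues, i.e. offset(1, 0, 0) * sizeof(LdsDataType) - size_per_buf,
// and m0 is bumped by that amount after every async issue.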
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// printf("off %d\n", page_idx_[I0]);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
const auto page_offset = page_idx_[idx_gather];
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// read from distributed tensor
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
// printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
}
else
{
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
vec_value,
bool_constant<oob_conditional_check>{});
}
// printf("coord_offset:%d, scatter_offset:%d \n",
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
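// Usage sketch (hypothetical names): scatter a distributed tile back through the page table:
//   w.store(tile); // rows along the gather dim are written through page_idx_[row]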
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
BottomTensorIndex step_new = step;
step_new(HsGatherDim) = 0;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_coords_(iCoord)(I1),
step_new);
});
}
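// e.g. w.move({0, kK0}); // hypothetical step: the window origin advances kK0 columns, while
// the gather-dim component of the cached coordinates stays 0 (rows still come from page_idx_)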
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
{
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
{
valids_ = new_valids;
}
}
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx,
const ValidArray& new_valids)
{
update_page_idx(new_idx);
update_valids(new_valids);
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this uses more registers for FA, but fewer registers for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this uses fewer registers for FA, but more registers for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
PageIdxArray page_idx_;
ValidArray valids_;
// this contains:
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
number<HsGatherDim> = {},
number<NumCoord> = {})
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
HsGatherDim,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
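// Usage sketch (hypothetical names): build a gather window over a paged KV cache:
//   auto w = make_tile_scatter_gather(kv_view,
//                                     make_tuple(number<kM>{}, number<kK>{}),
//                                     {0, 0},
//                                     dstr,
//                                     page_idx);
//   auto tile = w.load();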
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
origin,
tile_distribution,
page_idx,
number<HsGatherDim>{});
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
page_idx,
number<HsGatherDim>{});
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
const StaticValidArray_& valids,
number<HsGatherDim> = {},
number<NumCoord> = {})
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
remove_cvref_t<StaticValidArray_>,
HsGatherDim,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
typename StaticValidArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
const StaticValidArray& valids,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
origin,
tile_distribution,
page_idx,
valids,
number<HsGatherDim>{});
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
typename StaticValidArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
const StaticValidArray& valids,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
page_idx,
valids,
number<HsGatherDim>{});
}
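// Usage sketch with validity masking (hypothetical names): `valids` flags which gather rows
// may actually be accessed; invalid rows take the guarded element-access path:
//   auto w = make_tile_scatter_gather(tile_window, dstr, page_idx, valids);
//   w.store(tile);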
} // namespace ck_tile

File diff suppressed because it is too large

View File

@@ -0,0 +1,256 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/**
* @brief This class provides description of tile windowed view on the device memory.
*
* @note This class does not provide any functions to read or modify device memory.
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
*/
template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
struct tile_window_base
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
// Delegate to child if it implements extra logic
static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
}
// Default no-op; can be overridden in child
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {}
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
// Delegate to child if it implements extra movement logic
static_cast<TileWindowType_*>(this)->move_extended(step);
}
// Default no-op; can be overridden in child
CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {}
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
WindowLengths window_lengths_;
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
};
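// CRTP sketch (hypothetical child type): a derived window overrides only the hooks it needs,
// e.g. to keep cached per-thread coordinates in sync with the moving origin:
//   struct my_window : tile_window_base<my_window, TensorView, Lengths>
//   {
//       CK_TILE_DEVICE void move_extended(const BottomTensorIndex& step)
//       {
//           /* update pre-computed coordinates by `step` here */
//       }
//   };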
template <typename TileWindowType_,
typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_>
struct tile_window_with_tile_dstr_base
: public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
{
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using TileWindowBase = tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
// using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord = decltype(make_tensor_coordinate(
typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{}));
static_assert(TileDstr::is_static(), "wrong!");
static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, TileWindowBase::NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
struct Traits
{
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename TileWindowBase::DataType>>::PackedSize;
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
using vector_t =
thread_buffer<typename TileWindowBase::DataType, ScalarPerVector / PackedSize>;
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto thread_tensor_lengths_ys =
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_),
false /*!!! no snaked curve! */>{};
}
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
};
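// Example of the selection above: for ys vector lengths {2, 8} with strides {8, 1}, the scan
// picks VectorDimY = 1 and ScalarPerVector = 8, so vector_t spans 8 contiguous (packed)
// elements along y1 per access.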
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; }
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
};
} // namespace ck_tile

File diff suppressed because it is too large

View File

@@ -18,6 +18,13 @@
#pragma once
namespace ck_tile {
template <typename TileWindow_>
CK_TILE_DEVICE void move_tile_window(TileWindow_& window,
const typename TileWindow_::BottomTensorIndex& step)
{
window.move(step);
}
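// Usage example (hypothetical step): move_tile_window(win, {0, kKPerBlock}); // slide along K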
// input an lds store tile, extract some information from it
// used to set the m0 value for the gfx9 series
template <typename LdsTileWindow_>

View File

@@ -83,12 +83,14 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) {
if constexpr(vec_length_in == 1)
return 1;
else
return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
},
number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
@@ -101,51 +103,90 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
if constexpr(num_vec_in == 1 || num_vec_out == 1)
{
// loop over SFC
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
constexpr auto idx_y_in =
generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
constexpr auto idx_y_out_tmp =
generate_array([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
if constexpr(vec_length_in == 1)
{
out_tensor.get_thread_buffer()[number<out_offset>{}] =
in_tensor.get_thread_buffer()[number<in_offset>{}];
}
else
{
using Vec = array<DataType, vec_length_in>;
out_tensor.get_thread_buffer().template get_as<Vec>(
number<out_offset / vec_length_in>{}) =
in_tensor.get_thread_buffer().template get_as<Vec>(
number<in_offset / vec_length_in>{});
}
});
}
else
{
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// in/out vectors to be transposed
thread_buffer<InVec, num_vec_in> in_vectors;
thread_buffer<OutVec, num_vec_out> out_vectors;
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_tuple(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
});
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) {
return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}
}
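// Note: when either vector length is 1 there is nothing to transpose in-register, so the fast
// path above copies scalars (or whole input vectors) straight through, and transpose_vectors
// is only instantiated on the general path.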
} // namespace detail

View File

@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdio.h>
#include <tuple>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
template <auto... val>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
template <typename... type>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
template <char... Xs>
struct str_literal
{
static constexpr const char data[] = {Xs..., '\0'};
static constexpr const size_t size = sizeof...(Xs);
template <char... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
{
return str_literal<Xs..., Ys...>{};
}
template <index_t N, char... Ys>
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
{
if constexpr(N == 0)
return str_literal<>{};
else if constexpr(N == 1)
return str_literal<Xs...>{};
else
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
}
};
#define make_str_literal(lit_) \
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
template <size_t... Idx>
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
makeTuple(std::index_sequence<Idx...>) noexcept
{
return {};
}
constexpr size_t constexpr_strlen(const char* c)
{
size_t t = 0;
while(*c++)
++t;
return t;
}
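// Example: constexpr auto fmt = make_str_literal("%5d"); // yields str_literal<'%', '5', 'd'>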
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor;
template <typename T_, index_t N_>
struct thread_buffer;
// Usage example: CK_PRINTF<float>{}(tensor);
template <typename ConvertTo = void,
typename FMT = str_literal<>,
typename PREFIX = str_literal<>,
typename SUFFIX = str_literal<>>
struct CK_PRINTF;
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
struct CK_PRINTF<ConvertTo,
str_literal<FMTChars...>,
str_literal<PREFIXChars...>,
str_literal<SUFFIXChars...>>
{
template <typename T>
CK_TILE_HOST_DEVICE static constexpr auto default_format()
{
if constexpr(std::is_same_v<T, float>)
return make_str_literal("%8.3f");
else if constexpr(std::is_same_v<T, int>)
return make_str_literal("%5d");
else if constexpr(std::is_same_v<T, unsigned int>)
return make_str_literal("%5u");
else
return make_str_literal("0x%08x");
}
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
{
constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
if constexpr(sizeof...(PREFIXChars) == 0)
return fmt_tid;
else
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
{
constexpr auto lf = make_str_literal("\n");
if constexpr(sizeof...(SUFFIXChars) == 0)
return lf;
else
return str_literal<SUFFIXChars...>{} + lf;
}
template <typename T, index_t N, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
std::integer_sequence<index_t, Is...>) const
{
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
decltype(default_format<Y>()),
str_literal<FMTChars...>>;
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
#pragma clang diagnostic pop
}
template <typename T, index_t N>
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
{
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
}
template <typename... TS>
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
{
return operator()(tensor.get_thread_buffer());
}
};
template <typename ConvertTo = void,
typename FMT = str_literal<>,
typename PREFIX = str_literal<>,
typename SUFFIX = str_literal<>>
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
{
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
template <typename T>
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
{
if(get_thread_id() < get_warp_size())
base_t::operator()(buf);
}
};
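// Usage example: CK_PRINTF_WARP0<float>{}(tensor); // like CK_PRINTF, but only warp 0 prints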
} // namespace ck_tile

View File

@@ -202,3 +202,7 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
} // namespace ck_tile
// environment variable to enable logging:
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)

View File

@@ -58,6 +58,30 @@ struct static_for
}
};
namespace detail {
template <typename T, T... Is>
struct applier
{
template <typename F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
(f(number<Is>{}), ...);
}
};
template <int32_t Size> // == sizeof...(Is)
using make_applier = __make_integer_seq<applier, index_t, Size>;
} // namespace detail
template <index_t N>
struct static_for<0, N, 1> : detail::make_applier<N>
{
using detail::make_applier<N>::operator();
};
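// Usage example: static_for<0, 4, 1>{}([&](auto i) { /* i is number<0>{} ... number<3>{} */ });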
struct identity
{
template <typename T>

View File

@@ -38,7 +38,7 @@ struct magic_division32_bit_range
shift_u32++;
};
uint64_t tmp_u64 = static_cast<uint64_t>((1UL << shift_u32) - divisor) << 32;
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
return make_tuple(multiplier_u32, shift_u32);

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
/// can be printed.
template <typename T>
CK_TILE_HOST_DEVICE void print(const T&)
{
static_assert(sizeof(T) == 0,
"No print implementation available for this type. Please specialize "
"ck_tile::print for your type.");
}
/// Specialization for int
template <>
CK_TILE_HOST_DEVICE void print(const int& value)
{
printf("%d", value);
}
/// Specialization for float
template <>
CK_TILE_HOST_DEVICE void print(const float& value)
{
printf("%f", value);
}
/// Specialization for double
template <>
CK_TILE_HOST_DEVICE void print(const double& value)
{
printf("%f", value);
}
/// Specialization for long
template <>
CK_TILE_HOST_DEVICE void print(const long& value)
{
printf("%ld", value);
}
/// Specialization for unsigned int
template <>
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
{
printf("%u", value);
}
/// Specialization for char
template <>
CK_TILE_HOST_DEVICE void print(const char& value)
{
printf("%c", value);
}
/// Specialization for array
template <typename T, size_t N>
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
{
printf("[");
for(size_t i = 0; i < N; ++i)
{
if(i > 0)
printf(", ");
print(value[i]); // Recursively call print for each element
}
printf("]");
}
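/// Usage example:
///   int v[3] = {1, 2, 3};
///   print(v); // prints "[1, 2, 3]"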
} // namespace ck_tile

View File

@@ -26,7 +26,8 @@ struct Add
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
@@ -51,13 +52,25 @@ struct SquareAdd
{
return y + (x * x);
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
float x_ = type_convert<float>(x);
return type_convert<T>(y_ + (x_ * x_));
}
};
struct Max
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
@@ -65,7 +78,9 @@ struct Max
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, x);
@@ -76,7 +91,9 @@ struct AbsMax
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
@@ -84,7 +101,9 @@ struct AbsMax
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, abs(x));

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include <tuple>
#include <type_traits>
#include <stdint.h>
@@ -57,8 +58,8 @@ struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
struct nonesuch
{
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
void operator=(nonesuch const&) = delete;
};
@@ -127,4 +128,44 @@ struct is_any_of<CompareTo, FirstType, Rest...>
{
};
// Helper to check if a type is a specialization of a given template
template <typename Test, template <typename...> class RefTemplate>
struct is_specialization_of : std::false_type
{
};
template <template <typename...> class RefTemplate, typename... Args>
struct is_specialization_of<RefTemplate<Args...>, RefTemplate> : std::true_type
{
};
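// Example: is_specialization_of<std::tuple<int, float>, std::tuple>::value == true;
//          is_specialization_of<int, std::tuple>::value == false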
// Helper to get a tuple element or default type
namespace detail {
template <bool IsWithinBounds, std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch
{
using type = DefaultType;
};
template <std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch<true, Idx, Tuple, DefaultType>
{
using type = std::tuple_element_t<Idx, Tuple>;
};
} // namespace detail
template <typename Tuple_, std::size_t Idx, typename DefaultType>
struct tuple_element_or_default
{
using Tuple = remove_cvref_t<Tuple_>;
static constexpr bool is_within_bounds = Idx < std::tuple_size_v<Tuple>;
using type = typename detail::
tuple_element_or_default_dispatch<is_within_bounds, Idx, Tuple, DefaultType>::type;
};
template <typename Tuple_, std::size_t Idx, typename DefaultType>
using tuple_element_or_default_t =
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
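// Example: tuple_element_or_default_t<std::tuple<int, float>, 0, double> is int, while an
// out-of-bounds index (e.g. 5) falls back to double.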
} // namespace ck_tile

View File

@@ -49,7 +49,7 @@ struct composes<F>
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename... Ts>
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
template <typename SaturateType>
struct saturates
@@ -57,8 +57,8 @@ struct saturates
// NOTE: this function does not return SaturateType value
// it is user's responsibility to do further cast or not
template <typename AccType>
CK_TILE_HOST_DEVICE constexpr auto
operator()(const AccType& a_) const -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
{
return clamp(a_,
type_convert<AccType>(numeric<SaturateType>::lowest()),

View File

@@ -9,7 +9,9 @@
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/flush_icache.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/joinable_thread.hpp"
@@ -25,6 +27,8 @@
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp"
#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
@@ -34,5 +38,8 @@
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
#include "ck_tile/host/rotating_buffers.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/host/timer.hpp"

View File

@@ -18,16 +18,36 @@
namespace ck_tile {
/** @brief 8-bit floating point type */
using F8 = ck_tile::fp8_t;
/** @brief 8-bit brain floating point type */
using BF8 = ck_tile::bf8_t;
/** @brief 16-bit floating point (half precision) type */
using F16 = ck_tile::half_t;
/** @brief 16-bit brain floating point type */
using BF16 = ck_tile::bf16_t;
/** @brief 32-bit floating point (single precision) type */
using F32 = float;
/** @brief 8-bit signed integer type */
using I8 = int8_t;
/** @brief 32-bit signed integer type */
using I32 = int32_t;
/**
* @brief Calculate relative error threshold for numerical comparisons
*
* Calculates the relative error threshold based on the mantissa bits and characteristics
* of the data types involved in the computation.
*
* @tparam ComputeDataType Type used for computation
* @tparam OutDataType Type used for output
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
* @param number_of_accumulations Number of accumulation operations performed
* @return Relative error threshold based on data type characteristics
*/
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
{
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
@@ -72,16 +92,23 @@ double get_relative_threshold(const int number_of_accumulations = 1)
return std::max(acc_error, midway_error);
}
/**
* @brief Calculate absolute error threshold for numerical comparisons
*
* Calculates the absolute error threshold based on the maximum possible value and
* the characteristics of the data types involved in the computation.
*
* @tparam ComputeDataType Type used for computation
* @tparam OutDataType Type used for output
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
* @param max_possible_num Maximum possible value in the computation
* @param number_of_accumulations Number of accumulation operations performed
* @return Absolute error threshold based on data type characteristics and maximum value
*/
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
const int number_of_accumulations = 1)
{
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
@@ -128,6 +155,16 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
return std::max(acc_error, midway_error);
}
/**
* @brief Stream operator overload for vector output
*
* Provides a formatted string representation of a vector, useful for debugging and logging.
*
* @tparam T Type of vector elements
* @param os Output stream
* @param v Vector to output
* @return Reference to the output stream
*/
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
@@ -145,6 +182,66 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
return os << "]";
}
/**
* @brief Check for size mismatch between output and reference ranges
*
* Verifies that the output and reference ranges are the same size.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if sizes mismatch
* @return True if sizes mismatch, false otherwise
*/
template <typename Range, typename RefRange>
CK_TILE_HOST bool check_size_mismatch(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!")
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return true;
}
return false;
}
/**
* @brief Report error statistics for numerical comparisons
*
* Outputs statistics about numerical comparison errors including count and maximum error.
*
* @param err_count Number of errors found
* @param max_err Maximum error value encountered
* @param total_size Total number of elements compared
*/
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
/**
* @brief Check errors between floating point ranges using the specified tolerances.
*
* Compares two ranges of floating point values within specified relative and absolute tolerances.
* This overload handles standard floating point types except half precision floating point.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
@@ -158,12 +255,9 @@ check_err(const Range& out,
double atol = 3e-6,
bool allow_infinity_ref = false)
{
if(check_size_mismatch(out, ref, msg))
return false;
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -196,15 +290,27 @@ check_err(const Range& out,
}
if(!res)
{
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
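// Usage example (hypothetical data and helpers):
//   std::vector<float> out = run_kernel(), ref = run_reference();
//   bool ok = check_err(out, ref, "Error: Incorrect results!", 1e-5, 3e-6);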
/**
* @brief Check errors between floating point ranges using the specified tolerances
*
* Compares two ranges of brain floating point values within specified relative and absolute
* tolerances.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
@@ -217,12 +323,8 @@ check_err(const Range& out,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(check_size_mismatch(out, ref, msg))
return false;
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -256,15 +358,28 @@ check_err(const Range& out,
}
if(!res)
{
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
/**
* @brief Check errors between half precision floating point ranges
*
* Compares two ranges of half precision floating point values within specified tolerances.
* This specialization handles the specific requirements and characteristics of half precision
* floating point comparisons.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
@@ -277,12 +392,8 @@ check_err(const Range& out,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(check_size_mismatch(out, ref, msg))
return false;
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -315,15 +426,26 @@ check_err(const Range& out,
}
if(!res)
{
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
/**
* @brief Check errors between integer ranges
*
* Compares two ranges of integer values with an absolute tolerance.
* This specialization handles integer types and optionally int4_t when the
* experimental bit int extension is enabled.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param atol Absolute tolerance
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
@@ -339,12 +461,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
double = 0,
double atol = 0)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
bool res{true};
int err_count = 0;
@@ -370,15 +488,28 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, static_cast<double>(max_err), ref.size());
}
return res;
}
/**
* @brief Check errors between FP8 ranges
*
* Specialized comparison for 8-bit floating point values that takes into account
* the unique characteristics and limitations of FP8 arithmetic, including
* rounding point distances and special handling of infinity values.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param max_rounding_point_distance Maximum allowed distance between rounding points
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
@@ -390,12 +521,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
double atol = 1e-1,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -447,15 +574,27 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
/**
* @brief Check errors between BF8 ranges
*
* Specialized comparison for 8-bit brain floating point values that considers
* the specific numerical properties and error characteristics of the BF8 format.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @param rtol Relative tolerance
* @param atol Absolute tolerance
* @param allow_infinity_ref Whether to allow infinity in reference values
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
@@ -467,12 +606,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
double atol = 1e-3,
bool allow_infinity_ref = false)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
if(check_size_mismatch(out, ref, msg))
return false;
}
const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
@@ -505,11 +640,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
report_error_stats(err_count, max_err, ref.size());
}
return res;
}
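
A short usage sketch of the check_err overloads above; run_kernel()/run_reference() are hypothetical producers of equally sized ranges of the same element type.

```cpp
// A minimal sketch, assuming two equally sized std::vector<ck_tile::half_t> ranges.
std::vector<ck_tile::half_t> out = run_kernel();    // hypothetical producer
std::vector<ck_tile::half_t> ref = run_reference(); // hypothetical producer
bool pass = ck_tile::check_err(out, ref, "Error: incorrect results!",
                               /*rtol=*/1e-3, /*atol=*/1e-3);
```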

View File

@@ -33,13 +33,14 @@ struct IsCharArray<const char (&)[N]> : std::true_type
};
template <typename... Ts>
inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v<Ts, std::string_view> ||
IsCharArray<Ts>::value ||
std::is_same_v<Ts, char>)&&...);
inline constexpr bool AllConvertibleToStringView =
((std::is_convertible_v<Ts, std::string_view> || IsCharArray<Ts>::value ||
std::is_same_v<Ts, char>) &&
...);
template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
[[nodiscard]] auto
concat(const Ts&... xs) -> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
{
using ::operator<<;
thread_local std::ostringstream oss;
@@ -78,8 +79,8 @@ template <std::size_t N>
}
template <typename... Ts>
auto concatInto(std::string& result, const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
auto concatInto(std::string& result,
const Ts&... xs) -> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
{
const std::size_t space = (1 + ... + getSize(xs));
result.reserve(result.size() + space);
@@ -87,8 +88,8 @@ auto concatInto(std::string& result, const Ts&... xs)
}
template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
[[nodiscard]] auto
concat(const Ts&... xs) -> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
{
std::string result;
concatInto(result, xs...);

View File

@@ -20,10 +20,35 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
}
/**
* @brief Container for storing data in GPU device memory
* @brief Manages device memory allocation and host-device data transfers
*
* DeviceMem encapsulates GPU memory management operations using HIP runtime API.
* It provides functionality for allocating device memory, transferring data between
* host and device, and performing basic memory operations.
*
* Key features:
* - Automatic memory allocation and deallocation
* - Host-to-device and device-to-host data transfers
* - Memory initialization operations
* - Integration with HostTensor for simplified data handling
*
* Usage example:
* ```
 * // Allocate a host tensor and matching device memory
 * HostTensor<float> AHostData({256});
 * DeviceMem d_mem(AHostData.get_element_space_size_in_bytes());
 *
 * // Transfer data to device
 * d_mem.ToDevice(AHostData.data());
 *
 * // Retrieve data from device
 * HostTensor<float> ResultHostTensor({256});
 * d_mem.FromDevice(ResultHostTensor.data());
* ```
*/
struct DeviceMem
{
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
@@ -163,8 +188,8 @@ struct DeviceMem
}
}
void* mpDeviceBuf;
std::size_t mMemSize;
void* mpDeviceBuf; ///< pointer to device buffer
std::size_t mMemSize; ///< size of device buffer in bytes
};
} // namespace ck_tile

View File

@@ -0,0 +1,75 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <string>
#include <string_view>
#include <hip/hip_runtime.h>
namespace ck_tile {
constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u)
{
return str.empty() ? h
: fnv1a_hash(str.substr(1),
(h ^ static_cast<unsigned char>(str.front())) * 16777619u);
}
inline std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string raw_name(props.gcnArchName);
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
switch(fnv1a_hash(name))
{
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
case fnv1a_hash("Ellesmere"):
case fnv1a_hash("Baffin"):
case fnv1a_hash("RacerX"):
case fnv1a_hash("Polaris10"):
case fnv1a_hash("Polaris11"):
case fnv1a_hash("Tonga"):
case fnv1a_hash("Fiji"):
case fnv1a_hash("gfx800"):
case fnv1a_hash("gfx802"):
case fnv1a_hash("gfx804"): return "gfx803";
case fnv1a_hash("Vega10"):
case fnv1a_hash("gfx901"): return "gfx900";
case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030";
default: return name;
}
}
inline bool is_gfx11_supported()
{
return get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
get_device_name() == "gfx1102" || get_device_name() == "gfx1103" ||
get_device_name() == "gfx1150" || get_device_name() == "gfx1151" ||
get_device_name() == "gfx1152";
}
inline bool is_gfx12_supported()
{
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
}
inline bool is_load_tr_supported()
{
// Check if load transpose is supported.
return get_device_name() == "gfx950";
}
} // namespace ck_tile
#endif
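
Since fnv1a_hash is constexpr, device-name strings can be matched in a plain switch, which is what the mapping above does. A minimal sketch (case labels are illustrative; hash collisions are theoretically possible, so labels should come from a known name set):

```cpp
switch(ck_tile::fnv1a_hash(ck_tile::get_device_name()))
{
case ck_tile::fnv1a_hash("gfx90a"): /* MI200-specific path */ break;
case ck_tile::fnv1a_hash("gfx942"): /* MI300-specific path */ break;
default: /* generic path */ break;
}
```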

View File

@@ -8,6 +8,7 @@
#include <iterator>
#include <optional>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <unordered_set>
@@ -17,13 +18,31 @@
namespace ck_tile {
/**
* @brief Functor for filling a range with randomly generated values from a uniform distribution.
*
* This struct provides functionality to fill iterators or ranges with random values
* generated from a uniform distribution. It supports both single-threaded and
* multi-threaded operation.
*
* @tparam T The target type for the generated values.
*
* @note The multi-threaded implementation is not guaranteed to provide perfectly
* distributed values across threads.
*
* @example
*
* // Direct usage without creating a separate variable:
* ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host_tensor);
*/
template <typename T>
struct FillUniformDistribution
{
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: threaded does not guarantee the distribution between thread
// ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed
// across threads).
bool threaded = false;
template <typename ForwardIter>
@@ -45,7 +64,7 @@ struct FillUniformDistribution
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
@@ -74,6 +93,60 @@ struct FillUniformDistribution
}
};
template <>
struct FillUniformDistribution<ck_tile::pk_int4_t>
{
float a_{-8.f}; // same type as primary template so that
// `FillUniformDistribution<Type>{-5.0f, 5.0f}` works for all types
float b_{7.f};
std::optional<uint32_t> seed_{11939};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
if(a_ < -8.0f || b_ > 7.0f)
{
throw std::runtime_error(
"a_ or b_ of FillUniformDistribution<ck_tile::pk_int4_t> is out of range.");
}
int min_value = static_cast<int>(a_);
int max_value = static_cast<int>(b_);
constexpr auto int4_array = std::array<uint8_t, 16>{0x88,
0x99,
0xaa,
0xbb,
0xcc,
0xdd,
0xee,
0xff,
0x00,
0x11,
0x22,
0x33,
0x44,
0x55,
0x66,
0x77};
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
// uniform_int_distribution bounds are inclusive, so draw from [0, max_value - min_value]
std::uniform_int_distribution<std::int32_t> dis(0, max_value - min_value);
while(first != last)
{
int randomInt = dis(gen);
*first = int4_array[randomInt + (min_value + 8)];
++first;
}
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
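
A usage sketch for the packed-int4 specialization above; each pk_int4_t byte receives two signed 4-bit lanes drawn from [-8, 7].

```cpp
// A minimal sketch; 256 packed elements hold 512 int4 values.
ck_tile::HostTensor<ck_tile::pk_int4_t> a({256});
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-8.f, 7.f}(a);
```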
namespace impl {
// clang-format off
@@ -169,7 +242,7 @@ struct FillNormalDistribution
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
: std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
@@ -280,7 +353,7 @@ struct FillMonotonicSeq
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = init_value_]() mutable {
std::generate(first, last, [=, *this, n = init_value_]() mutable {
auto tmp = n;
if constexpr(std::is_same_v<decltype(tmp), pk_int4_t>)
{
@@ -315,7 +388,7 @@ struct FillStepRange
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = start_value_]() mutable {
std::generate(first, last, [=, *this, n = start_value_]() mutable {
auto tmp = n;
n += step_;
if constexpr(IsAscending)
@@ -334,9 +407,10 @@ struct FillStepRange
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillStepRange&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
@@ -355,9 +429,53 @@ struct FillConstant
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillConstant&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillConstant&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
//----------------------------------------------------------------------------------------------
/// @brief Transforms given input to fit 2:4 structured sparsity pattern so
/// every subgroup of 4 elements contain at most 2 non-zero elements
template <typename T>
struct AdjustToStructuredSparsity
{
size_t start{0};
// masks represent all valid 2:4 structured sparsity permutations
// clang-format off
static constexpr int32_t masks[] = {0, 0, 1, 1,
0, 1, 0, 1,
0, 1, 1, 0,
1, 0, 0, 1,
1, 0, 1, 0,
1, 1, 0, 0,
0, 0, 0, 1,
0, 0, 1, 0,
0, 1, 0, 0,
1, 0, 0, 0};
// clang-format on
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::transform(first, last, first, [=, *this, index = start](T val) mutable {
auto tmp = val * masks[index % (sizeof(masks) / sizeof(int32_t))];
index += 1;
return type_convert<T>(tmp);
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const AdjustToStructuredSparsity&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
@@ -396,9 +514,10 @@ struct FillTrigValue
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillTrigValue&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillTrigValue&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
namespace ck_tile {
static __global__ void flush_cache()
{
asm __volatile__("s_icache_inv \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t" ::
:);
}
} // namespace ck_tile

View File

@@ -85,6 +85,19 @@ CK_TILE_HOST auto construct_f_unpack_args(F, T args)
return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
}
/**
* @brief Descriptor for tensors in host memory.
*
* HostTensorDescriptor manages the shape (dimensions) and memory layout (strides)
* of a tensor in host memory. It provides functionality to:
* - Store tensor dimensions and strides
* - Calculate default strides for contiguous memory layout
* - Convert multi-dimensional indices to linear memory offsets
* - Query tensor metadata (dimensions, element counts, etc.)
*
* The class supports both automatic stride calculation for contiguous memory layout
* and custom strides for more complex memory patterns.
*/
struct HostTensorDescriptor
{
HostTensorDescriptor() = default;
@@ -138,12 +151,35 @@ struct HostTensorDescriptor
}
std::size_t get_num_of_dimension() const { return mLens.size(); }
/**
* @brief Calculates the total number of elements in the tensor.
*
* Computes the product of all dimension lengths to determine the
* total element count in the tensor.
*
* @pre The lengths array (mLens) and strides array (mStrides) must have
* the same size.
*
* @return The total number of elements in the tensor.
*/
std::size_t get_element_size() const
{
assert(mLens.size() == mStrides.size());
return std::accumulate(
mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
/**
* @brief Calculates the total element space required for the tensor in memory.
*
* This method computes the minimum size of contiguous memory needed to store
* all elements of the tensor, taking into account the tensor's dimensions and
 * strides. The calculation is based on the formula: 1 + sum((length_i - 1) * stride_i)
 * over all dimensions.
*
* Dimensions with length 0 are skipped in this calculation.
*
* @return The size of the tensor's element space (number of elements).
*/
std::size_t get_element_space_size() const
{
std::size_t space = 1;
@@ -165,6 +201,18 @@ struct HostTensorDescriptor
const std::vector<std::size_t>& get_strides() const { return mStrides; }
/**
* @brief Calculates the linear offset from multi-dimensional indices.
*
* Converts a set of N-dimensional indices into a single linear offset by computing
* the inner product of the indices with the tensor's strides.
*
* @tparam Is Parameter pack of index types (should be convertible to std::size_t)
* @param is Variable number of indices, one for each dimension of the tensor
* @return std::size_t Linear offset corresponding to the given multi-dimensional indices
*
* @pre The number of indices must match the number of dimensions in the tensor
*/
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
@@ -173,7 +221,16 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
/**
* @brief Calculates the linear memory offset from a multi-dimensional index
*
* Computes the linear offset by performing an inner product between the provided
* multi-dimensional indices and the tensor's strides.
*
* @param iss Vector containing the multi-dimensional indices
* @return The calculated linear offset as a size_t
*/
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
{
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
@@ -194,8 +251,8 @@ struct HostTensorDescriptor
}
private:
std::vector<std::size_t> mLens;
std::vector<std::size_t> mStrides;
std::vector<std::size_t> mLens; ///< Lengths of each dimension
std::vector<std::size_t> mStrides; ///< Strides for each dimension
};
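
A small worked example of the descriptor arithmetic above, assuming the initializer-list constructor derives packed row-major strides:

```cpp
// lengths {2, 3}; packed row-major strides default to {3, 1}
ck_tile::HostTensorDescriptor desc({2, 3});
std::size_t count  = desc.get_element_size();            // 2 * 3 == 6
std::size_t space  = desc.get_element_space_size();      // 1 + 1*3 + 2*1 == 6
std::size_t offset = desc.GetOffsetFromMultiIndex(1, 2); // 1*3 + 2*1 == 5
```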
template <typename New2Old>
@@ -321,7 +378,7 @@ struct HostTensor
~HostTensor() = default;
HostTensor& operator=(const HostTensor&) = default;
HostTensor& operator=(HostTensor&&) = default;
HostTensor& operator=(HostTensor&&) = default;
template <typename FromT>
explicit HostTensor(const HostTensor<FromT>& other) : HostTensor(other.template CopyAsType<T>())
@@ -352,7 +409,13 @@ struct HostTensor
}
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
void SetZero()
{
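// e8m0_t is an exponent-only format with no encoding for zero, so "zeroing"
// falls back to the neutral scale 1.0f (i.e. 2^0)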
if constexpr(std::is_same_v<T, e8m0_t>)
std::fill(mData.begin(), mData.end(), e8m0_t{1.f});
else
std::fill(mData.begin(), mData.end(), 0);
}
template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
@@ -483,9 +546,12 @@ struct HostTensor
return mData[GetOffsetFromMultiIndex(is...)];
}
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
T& operator()(const std::vector<std::size_t>& idx)
{
return mData[GetOffsetFromMultiIndex(idx)];
}
const T& operator()(std::vector<std::size_t> idx) const
const T& operator()(const std::vector<std::size_t>& idx) const
{
return mData[GetOffsetFromMultiIndex(idx)];
}
@@ -662,6 +728,8 @@ struct HostTensor
file << type_convert<float>(itm) << std::endl;
else if(dtype == "int")
file << type_convert<int>(itm) << std::endl;
else if(dtype == "int8_t")
file << static_cast<int>(type_convert<ck_tile::int8_t>(itm)) << std::endl;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error
@@ -681,6 +749,24 @@ struct HostTensor
Data mData;
};
/**
* @brief Creates a host tensor descriptor with specified dimensions and layout
*
* Constructs a HostTensorDescriptor with appropriate strides based on whether the tensor
* layout is row-major or column-major. This is determined via the compile-time template
* parameter `is_row_major`.
*
* @tparam is_row_major Compile-time flag indicating if the layout is row-major (true) or
* column-major (false)
*
* @param row Number of rows in the tensor
* @param col Number of columns in the tensor
* @param stride Stride between adjacent rows (for row-major) or columns (for column-major)
*
* @return HostTensorDescriptor with shape {row, col} and strides:
* - For row-major: {stride, 1}
* - For column-major: {1, stride}
*/
template <bool is_row_major>
auto host_tensor_descriptor(std::size_t row,
std::size_t col,
@@ -698,6 +784,7 @@ auto host_tensor_descriptor(std::size_t row,
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
}
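
A usage sketch; the trailing bool_constant tag is assumed to be how is_row_major is deduced, mirroring get_default_stride below:

```cpp
// A minimal sketch for a row-major M x K matrix A and a column-major K x N matrix B.
const auto a_desc = ck_tile::host_tensor_descriptor(M, K, /*stride=*/K,
                                                    ck_tile::bool_constant<true>{});  // strides {K, 1}
const auto b_desc = ck_tile::host_tensor_descriptor(K, N, /*stride=*/K,
                                                    ck_tile::bool_constant<false>{}); // strides {1, K}
```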
template <bool is_row_major>
auto get_default_stride(std::size_t row,
std::size_t col,
@@ -718,5 +805,4 @@ auto get_default_stride(std::size_t row,
else
return stride;
}
} // namespace ck_tile

View File

@@ -15,7 +15,7 @@ struct joinable_thread : std::thread
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()

View File

@@ -1,23 +1,31 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <numeric>
#include <functional>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <hip/hip_runtime.h>
namespace ck_tile {
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif
__global__ void kentry(Args... args)
{
#if defined(__HIP_DEVICE_COMPILE__)
Kernel{}(args...);
#else
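// host-side pass: the fold expression touches every argument to mark it used,
// compiling the kernel body away while silencing unused-parameter warnings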
(..., (ignore = args, 0));
#endif
}
//
@@ -51,6 +59,60 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
}
}
// Measure the preprocess time during the cold iterations
template <typename TimerType, typename PreprocessFunc>
CK_TILE_HOST double
preprocess_profiling_impl(TimerType timer, const stream_config& s, PreprocessFunc preprocess)
{
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++)
{
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess();
}
}
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
}
template <typename TimerType, typename CallablesFunc, typename PreprocessFunc = std::nullptr_t>
CK_TILE_HOST double timing_loop_impl(TimerType timer,
const stream_config& s,
CallablesFunc&& callables_func,
PreprocessFunc preprocess = nullptr)
{
for(int i = 0; i < s.cold_niters_; i++)
{
callables_func();
}
// Only profile preprocess if it's provided
auto preprocess_time = 0.0;
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess_time = preprocess_profiling_impl(gpu_timer{}, s, preprocess);
}
int i = 0;
timer.start(s.stream_id_);
while(i < s.nrepeat_)
{
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess();
}
callables_func();
i++;
}
timer.stop(s.stream_id_);
if(!i)
return 0.;
return (timer.duration() / s.nrepeat_) - preprocess_time;
}
// clang-format off
/*
* launch_kernel()
@@ -81,37 +143,48 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
template <typename... Callables>
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
{
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
if(!s.time_kernel_)
{
launch_and_check(s, std::forward<Callables>(callables)...);
return 0;
}
auto time_launches = [&](auto timer) {
// warmup
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
};
auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
if(s.is_gpu_timer_)
{
return time_launches(gpu_timer{});
return timing_loop_impl(gpu_timer{}, s, callables_func);
}
else
{
return time_launches(cpu_timer{});
return timing_loop_impl(cpu_timer{}, s, callables_func);
}
}
template <typename PreprocessFunc, typename... Callables>
CK_TILE_HOST float
launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Callables&&... callables)
{
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
if(!s.time_kernel_)
{
preprocess();
launch_and_check(s, std::forward<Callables>(callables)...);
return 0;
}
auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
if(s.is_gpu_timer_)
{
return timing_loop_impl(gpu_timer{}, s, callables_func, preprocess);
}
else
{
return timing_loop_impl(cpu_timer{}, s, callables_func, preprocess);
}
}
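
A usage sketch of the timed path; Kernel, grids, blocks, and kargs are placeholders, and make_kernel (assumed here to wrap kentry into a stream_config-taking callable, as elsewhere in this header) supplies the callable:

```cpp
// A minimal sketch; returns the average kernel time in milliseconds.
ck_tile::stream_config s{nullptr, /*time_kernel=*/true};
float avg_ms = ck_tile::launch_kernel(
    s, ck_tile::make_kernel(Kernel{}, grids, blocks, /*lds_byte=*/0, kargs));
```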
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -11,6 +11,110 @@
namespace ck_tile {
template <typename ADataType,
typename QDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
uint32_t QuantGroupSize,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0, v_block_acc = 0;
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, pk_int4_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> ||
std::is_same_v<CDataType, ck_tile::half_t>);
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_block_acc += v_a * v_b;
// Apply group dequant scale
if((k + 1) % QuantGroupSize == 0)
{
float scale = 0.f;
index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
{
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
v_block_acc *= scale;
v_acc += v_block_acc;
v_block_acc = 0;
}
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
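
A usage sketch of the group-quantized reference above; the scale shapes follow from the indexing (q(m, k/QuantGroupSize) when aquant is true, q(k/QuantGroupSize, n) otherwise), and all sizes are illustrative:

```cpp
// A minimal sketch; K must be a multiple of QuantGroupSize.
constexpr uint32_t QuantGroupSize = 128;
const std::size_t M = 16, N = 16, K = 256;
ck_tile::HostTensor<ck_tile::fp8_t> a({M, K});
ck_tile::HostTensor<float>          q({M, K / QuantGroupSize}); // aquant == true
ck_tile::HostTensor<ck_tile::fp8_t> b({K, N});
ck_tile::HostTensor<float>          c({M, N});
ck_tile::reference_gemm_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, float,
                              QuantGroupSize, /*aquant=*/true>(a, q, b, c);
```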
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -71,6 +175,58 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ACCElementOp,
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
HostTensor<CDataType>& c_m_n,
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mk_kn_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
for(std::size_t k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_k_n(k, n);
v_acc +=
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
}
CDataType v_c = 0;
if constexpr(DsDataType::size() == 0)
{
acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
}
else if constexpr(DsDataType::size() == 1)
{
acc_element_op(v_c,
ck_tile::type_convert<float>(v_acc),
ck_tile::type_convert<float>(ds_m_n[0](m, n)));
}
else if constexpr(DsDataType::size() == 2)
{
acc_element_op(v_c,
ck_tile::type_convert<float>(v_acc),
ck_tile::type_convert<float>(ds_m_n[0](m, n)),
ck_tile::type_convert<float>(ds_m_n[1](m, n)));
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename AccDataType,

View File

@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST void
reference_grouped_conv_bwd_weight(const HostTensor<InDataType>& input,
HostTensor<WeiDataType>& weight,
const HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>)
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
output.get_num_of_dimension() == NDimSpatial + 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto k, auto c, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
{
for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
{
InDataType v_in = input(g, n, c, wi);
OutDataType v_out = output(g, n, k, wo);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_in);
}
}
}
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
weight(g, k, c, x) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
weight.get_lengths()[0],
weight.get_lengths()[1],
weight.get_lengths()[2],
weight.get_lengths()[3])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
auto func = [&](auto g, auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
{
for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
if(hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
{
InDataType v_in = input(g, n, c, hi, wi);
OutDataType v_out = output(g, n, k, ho, wo);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_in);
}
}
}
}
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
weight(g, k, c, y, x) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
weight.get_lengths()[0],
weight.get_lengths()[1],
weight.get_lengths()[2],
weight.get_lengths()[3],
weight.get_lengths()[4])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
{
for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
{
auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
if(di >= 0 &&
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
{
InDataType v_in = input(g, n, c, di, hi, wi);
OutDataType v_out = output(g, n, k, do_, ho, wo);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_in);
}
}
}
}
}
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
weight(g, k, c, z, y, x) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
weight.get_lengths()[0],
weight.get_lengths()[1],
weight.get_lengths()[2],
weight.get_lengths()[3],
weight.get_lengths()[4],
weight.get_lengths()[5])(std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error(
"Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,165 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
const HostTensor<WeiDataType>& weight,
HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>)
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
output.get_num_of_dimension() == NDimSpatial + 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
{
for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
{
InDataType v_in = input(g, n, c, wi);
WeiDataType v_wei = weight(g, k, c, x);
v_acc += ck_tile::type_convert<float>(v_in) *
ck_tile::type_convert<float>(v_wei);
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
output.get_lengths()[0],
output.get_lengths()[1],
output.get_lengths()[2],
output.get_lengths()[3])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
{
for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
if(hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
{
InDataType v_in = input(g, n, c, hi, wi);
WeiDataType v_wei = weight(g, k, c, y, x);
v_acc += ck_tile::type_convert<float>(v_in) *
ck_tile::type_convert<float>(v_wei);
}
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
output.get_lengths()[0],
output.get_lengths()[1],
output.get_lengths()[2],
output.get_lengths()[3],
output.get_lengths()[4])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
{
for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
{
auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
if(di >= 0 &&
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
{
InDataType v_in = input(g, n, c, di, hi, wi);
WeiDataType v_wei = weight(g, k, c, z, y, x);
v_acc += ck_tile::type_convert<float>(v_in) *
ck_tile::type_convert<float>(v_wei);
}
}
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
output.get_lengths()[0],
output.get_lengths()[1],
output.get_lengths()[2],
output.get_lengths()[3],
output.get_lengths()[4],
output.get_lengths()[5])(std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
}
}
} // namespace ck_tile
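
A usage sketch for the 2D case; layouts follow the indexing above (input G,N,C,Hi,Wi; weight G,K,C,Y,X; output G,N,K,Ho,Wo) and all extents are illustrative:

```cpp
// A minimal sketch: stride 1, no padding, no dilation, so Ho = Hi - Y + 1.
const std::size_t G = 2, N = 1, C = 8, K = 16, Hi = 16, Wi = 16, Y = 3, X = 3;
const std::size_t Ho = Hi - Y + 1, Wo = Wi - X + 1;
ck_tile::HostTensor<float> in({G, N, C, Hi, Wi});
ck_tile::HostTensor<float> wei({G, K, C, Y, X});
ck_tile::HostTensor<float> out({G, N, K, Ho, Wo});
ck_tile::reference_grouped_conv_fwd<2>(in, wei, out,
                                       {1, 1},  // strides
                                       {1, 1},  // dilations
                                       {0, 0},  // left pads
                                       {0, 0}); // right pads (unused)
```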

View File

@@ -9,7 +9,7 @@
namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
static_cast<uint32_t>(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24))
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
@@ -21,10 +21,12 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
index_t& unit_cnt,
const index_t experts,
const index_t unit_size,
const index_t tokens,
bool local_expert_masking,
bool skip_experts_with_zero_token = true)
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
// note: tokens may be smaller than topk_ids.mDesc.get_lengths()[0], which indicates the local-token case
const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
// allocate a temp buffer, and fill the value with [number_token|topk]
std::vector<std::vector<IndexType>> expert_tokens(

View File

@@ -30,4 +30,82 @@ reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m,
make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
// Generic reference reduce for arbitrary dimensions
template <
typename XDataType,
typename ComputeDataType,
typename YDataType,
typename ReduceOp,
typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to keep
typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to
// reduce
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
HostTensor<YDataType>& y_tensor,
ReduceOp reduce_op,
KeptDim kept_dim,
ReduceDims reduce_dims)
{
const auto& x_lengths = x_tensor.mDesc.get_lengths();
// Calculate total kept elements (product of all kept dimension lengths)
index_t total_kept_elements = 1;
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Calculate total reduce elements (product of all reduce dimension lengths)
index_t total_reduce_elements = 1;
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
auto f = [&](auto linear_kept_idx) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
// Convert linear kept index to multi-dimensional kept indices
std::vector<index_t> kept_indices(kept_dim.size());
index_t temp_kept = linear_kept_idx;
static_for<0, kept_dim.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = kept_dim.size() - 1 - i;
constexpr auto dim = kept_dim.at(dim_idx);
const auto len = x_lengths[dim];
kept_indices[dim_idx] = temp_kept % len;
temp_kept /= len;
});
for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
{
// Convert linear reduce index to multi-dimensional reduce indices
std::vector<index_t> reduce_indices(reduce_dims.size());
index_t temp_reduce = reduce_idx;
static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = reduce_dims.size() - 1 - i;
constexpr auto dim = reduce_dims.at(dim_idx);
const auto len = x_lengths[dim];
reduce_indices[dim_idx] = temp_reduce % len;
temp_reduce /= len;
});
// Build full input tensor indices by combining kept and reduce indices
std::vector<std::size_t> full_indices(x_lengths.size(), 0);
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// Access input tensor element
const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
v_acc = reduce_op(v_acc, v_a);
}
// Calculate output tensor index using kept indices
// The output tensor has the same structure as the kept dimensions
std::vector<std::size_t> y_indices(kept_dim.size());
static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });
y_tensor(y_indices) = type_convert<YDataType>(v_acc);
};
make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
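
A usage sketch: row sums of an M x N tensor, keeping dimension 0 and reducing dimension 1; ReduceOp::Add from reduce_operator.hpp is assumed to supply the identity value and binary op used above:

```cpp
// A minimal sketch; y(m) = sum over n of x(m, n).
const std::size_t M = 8, N = 64;
ck_tile::HostTensor<float> x({M, N});
ck_tile::HostTensor<float> y({M});
ck_tile::reference_reduce<float, float, float>(
    x, y, ck_tile::ReduceOp::Add{}, ck_tile::sequence<0>{}, ck_tile::sequence<1>{});
```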
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,7 +14,7 @@ CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
index_t rank = x.get_num_of_dimension();
assert(rank == y.get_num_of_dimension());
assert(static_cast<std::size_t>(rank) == y.get_num_of_dimension());
assert(dim == -1 || dim < rank);
index_t target_dim = dim == -1 ? (rank - 1) : dim;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -38,8 +38,8 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
{
// rank must be the same
index_t rank = x.get_num_of_dimension();
assert(rank == y_values.get_num_of_dimension());
assert(rank == y_indices.get_num_of_dimension());
assert(static_cast<std::size_t>(rank) == y_values.get_num_of_dimension());
assert(static_cast<size_t>(rank) == y_indices.get_num_of_dimension());
assert(dim == -1 || dim < rank);
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
@@ -47,7 +47,8 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
auto x_len = x.get_lengths();
assert(k <= topk_src_len);
assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
assert(static_cast<size_t>(k) == y_values.get_length(topk_dim) &&
static_cast<size_t>(k) == y_indices.get_length(topk_dim));
index_t n_parallel = x.get_element_size() / topk_src_len;

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType, typename BDataType>
void reference_transpose_elementwise(const HostTensor<ADataType>& a, HostTensor<BDataType>& b)
{
ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
// Ensure the b tensor is sized correctly for N x M
if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
{
throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
}
auto f = [&](auto i, auto j) {
auto v_a = a(i, j);
b(j, i) = ck_tile::type_convert<BDataType>(v_a);
};
make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,102 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
template <typename ADataType, typename BDataType>
struct RotatingMemWrapper
{
RotatingMemWrapper() = delete;
RotatingMemWrapper(const void* a_ptr_,
const void* b_ptr_,
std::size_t rotating_count_,
std::size_t size_a_,
std::size_t size_b_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
rotating_count(rotating_count_),
size_a(size_a_),
size_b(size_b_)
{
p_a_grids.push_back(a_ptr);
p_b_grids.push_back(b_ptr);
for(size_t i = 1; i < rotating_count; i++)
{
{
void* pADeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
p_a_grids.push_back(pADeviceBuf);
}
{
void* pBDeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
p_b_grids.push_back(pBDeviceBuf);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
a_ptr = p_a_grids[idx];
b_ptr = p_b_grids[idx];
}
}
void Print()
{
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
<< ", rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapper() noexcept
{
if(rotating_count > 1)
{
// restore ptr
a_ptr = p_a_grids[0];
b_ptr = p_b_grids[0];
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
}
}
}
private:
const void* a_ptr;
const void* b_ptr;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::size_t size_a = 0;
std::size_t size_b = 0;
std::vector<const void*> p_a_grids;
std::vector<const void*> p_b_grids;
};
inline void flush_icache()
{
hipDeviceProp_t deviceProps;
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
HIP_CHECK_ERROR(hipGetLastError());
}
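
A sketch of how the wrapper and flush_icache combine with launch_kernel_time_mask; buffer handles, byte sizes, and the kernel callable are placeholders, and feeding the rotated pointers back into the kernel arguments is assumed to be handled by the caller:

```cpp
// A minimal sketch: each timed iteration sees cold A/B copies and a cold icache.
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating(
    a_dev_ptr, b_dev_ptr, /*rotating_count=*/4, a_bytes, b_bytes);
rotating.Print();
auto preprocess = [&]() {
    ck_tile::flush_icache(); // invalidate the instruction cache between runs
    rotating.Next();         // advance to the next device copy of A/B
};
float avg_ms = ck_tile::launch_kernel_time_mask(s, preprocess, kernel_callable);
```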
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -20,6 +20,10 @@ namespace ck_tile {
*
* // create stream config with _some_stream_id_, and benchmark using cpu timer
* stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
*
 * // create stream config with _some_stream_id_, enable the gpu timer, and use a
 * // rotating buffer with the given rotating count
 * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, true, true, 1};
**/
struct stream_config
@@ -30,5 +34,7 @@ struct stream_config
int cold_niters_ = 3;
int nrepeat_ = 10;
bool is_gpu_timer_ = true; // keep compatible
bool flush_cache_ = false;
int rotating_count_ = 1;
};
} // namespace ck_tile

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime_api.h>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
namespace ck_tile {
static inline index_t get_available_compute_units(const stream_config& s)
{
constexpr static uint32_t MAX_MASK_DWORDS = 64;
// assume at most 64*32 = 2048 CUs
uint32_t cu_mask[MAX_MASK_DWORDS]{};
auto count_set_bits = [](uint32_t dword) {
index_t count = 0;
while(dword != 0)
{
if(dword & 0x1)
{
count++;
}
dword = dword >> 1;
}
return count;
};
HIP_CHECK_ERROR(hipExtStreamGetCUMask(s.stream_id_, MAX_MASK_DWORDS, &cu_mask[0]));
index_t num_cu = 0;
for(uint32_t i = 0; i < MAX_MASK_DWORDS; i++)
{
num_cu += count_set_bits(cu_mask[i]);
}
return num_cu;
};
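
A usage sketch; `stream` is a placeholder hipStream_t, possibly created with a CU mask via hipExtStreamCreateWithCUMask:

```cpp
// A minimal sketch: size a persistent grid by the CUs visible to the stream.
ck_tile::stream_config s{stream};
ck_tile::index_t num_cu = ck_tile::get_available_compute_units(s);
dim3 grids(num_cu); // e.g. one workgroup per available CU
```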
} // namespace ck_tile

View File

@@ -4,6 +4,10 @@
#pragma once
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"

View File

@@ -19,7 +19,6 @@ struct BatchedTransposeHostArgs
index_t batch;
index_t height;
index_t width;
// index_t dim_blocks;
index_t dim_stride;
index_t dim_block_h;
index_t dim_block_w;
@@ -28,10 +27,12 @@ struct BatchedTransposeHostArgs
template <typename Pipeline_>
struct BatchedTransposeKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::InputType;
CK_TILE_DEVICE static index_t counter = 0;
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::DataType;
struct BatchedTransposeKargs
{
@@ -46,11 +47,13 @@ struct BatchedTransposeKernel
using Kargs = BatchedTransposeKargs;
using Hargs = BatchedTransposeHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args)
{
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
size_t grid_size_z = h.batch;
const size_t grid_size_x =
ck_tile::integer_divide_ceil(host_args.height, host_args.dim_block_h);
const size_t grid_size_y =
ck_tile::integer_divide_ceil(host_args.width, host_args.dim_block_w);
const size_t grid_size_z = host_args.batch;
return dim3(grid_size_x, grid_size_y, grid_size_z);
}
@@ -66,62 +69,58 @@ struct BatchedTransposeKernel
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput;
static constexpr ck_tile::index_t VectorStrideInput = 1;
static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput;
static constexpr ck_tile::index_t VectorStrideOutput = 1;
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto offset = __builtin_amdgcn_readfirstlane(blockIdx.z * kargs.height * kargs.width);
static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
static_assert(kMPerThread == 1 && kNPerThread == 1);
const auto iDim = blockIdx.z;
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
static_cast<const Type*>(kargs.p_input) + offset,
make_tuple(kargs.height, kargs.width),
make_tuple(kargs.width, 1),
number<kNPerThread>{}, // TODO thread load value
number<1>{});
number<VectorSizeInput>{},
number<VectorStrideInput>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto y_n_m = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
static_cast<Type*>(kargs.p_output) + offset,
make_tuple(kargs.width, kargs.height),
make_tuple(kargs.height, 1),
number<kMPerThread>{},
number<1>{});
number<VectorSizeOutput>{},
number<VectorStrideOutput>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
sequence<kPadN, kPadM>{});
}();
auto x_block_window =
make_tile_window(x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM * kMPerBlock),
static_cast<ck_tile::index_t>(iN * kNPerBlock)});
auto x_block_window = make_tile_window(
x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
auto y_block_window =
make_tile_window(y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN * kNPerBlock),
static_cast<ck_tile::index_t>(iM * kMPerBlock)});
auto y_block_window = make_tile_window(
y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
Pipeline{}(x_block_window, y_block_window);
}
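Worked grid arithmetic for the GridSize change above (illustrative numbers): integer_divide_ceil(x, b) is the usual (x + b - 1) / b.

// e.g. height = 100, dim_block_h = 32 -> 4 blocks; width = 96, dim_block_w = 32 -> 3.
static_assert((100 + 32 - 1) / 32 == 4);
static_assert((96 + 32 - 1) / 32 == 3);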

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct BatchedTransposeCommonPolicy
{
CK_TILE_DEVICE static constexpr auto TileAccessPattern =
tile_distribution_pattern::thread_raked;
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeInputDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kLeadDimPerBlock = Problem::kNPerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kMPerBlock;
constexpr index_t kVectorSize = Problem::VectorSizeInput;
static_assert((kLeadDimPerBlock * kVectorSize) % kBlockSize == 0,
"(kLeadDimPerBlock * kVectorSize) must be a multiple of kBlockSize");
using TileEncodingPattern = TileDistributionEncodingPattern2D<kBlockSize,
kSecondDimPerBlock,
kLeadDimPerBlock,
kVectorSize,
TileAccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
};
} // namespace ck_tile
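A compile-time sketch of what the divisibility assert above requires, with hypothetical Problem values (not taken from any real instantiation):

struct ExampleTransposeProblem // hypothetical values for illustration
{
    static constexpr ck_tile::index_t kBlockSize      = 256;
    static constexpr ck_tile::index_t kMPerBlock      = 64;
    static constexpr ck_tile::index_t kNPerBlock      = 64;
    static constexpr ck_tile::index_t VectorSizeInput = 8;
};
// 64 lead-dim elements times vector size 8 is 512, a multiple of the 256-thread block.
static_assert((ExampleTransposeProblem::kNPerBlock *
               ExampleTransposeProblem::VectorSizeInput) %
                  ExampleTransposeProblem::kBlockSize == 0);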

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename Problem_, typename Policy_>
struct BatchedTransposeLdsPipeline
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = remove_cvref_t<typename Problem::DataType>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename InputTileWindow, typename OutputTileWindow>
CK_TILE_DEVICE void operator()(const InputTileWindow& input_window,
OutputTileWindow& output_window)
{
__shared__ char smem[GetSmemSize()];
auto input_tile_window =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto output_tile_window =
make_tile_window(output_window, Policy::template MakeOutputDistribution<Problem>());
DataType* p_lds_ptr = reinterpret_cast<DataType*>(smem);
constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor<Problem>();
auto input_lds_block =
make_tensor_view<address_space_enum::lds>(p_lds_ptr, in_lds_block_desc);
constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor<Problem>();
auto output_lds_block =
make_tensor_view<address_space_enum::lds>(p_lds_ptr, out_lds_block_desc);
auto copy_to_lds_window =
make_tile_window(input_lds_block,
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
{0, 0});
auto load_from_lds_window =
make_tile_window(output_lds_block,
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
{0, 0},
Policy::template MakeLdsLoadTileDistribution<Problem>());
auto x = load_tile(input_tile_window);
store_tile(copy_to_lds_window, x);
block_sync_lds();
auto y = load_tile_transpose(load_from_lds_window);
store_tile(output_tile_window, y);
}
};
} // namespace ck_tile
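A scalar, host-side model of the round-trip above (standalone sketch; the real pipeline does this with distributed tiles and a block-wide barrier between the store and the transposed load):

#include <array>

template <int M, int N>
std::array<float, M * N> transpose_model(const std::array<float, M * N>& x)
{
    std::array<float, M * N> lds{}; // stands in for shared memory
    for(int m = 0; m < M; ++m)      // store_tile(copy_to_lds_window, x)
        for(int n = 0; n < N; ++n)
            lds[m * N + n] = x[m * N + n];
    // block_sync_lds() would go here on the GPU
    std::array<float, M * N> y{};   // load_tile_transpose(load_from_lds_window)
    for(int n = 0; n < N; ++n)
        for(int m = 0; m < M; ++m)
            y[n * M + m] = lds[m * N + n];
    return y;
}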

View File

@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "batched_transpose_common_policy.hpp"
namespace ck_tile {
struct BatchedTransposeLdsPolicy : public BatchedTransposeCommonPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
{
return integer_least_multiple(
sizeof(typename Problem::DataType) *
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
16);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
{
constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
using OutTileDstrEncode =
typename OutputTileDistributionTraits<typename decltype(input_dstr)::DstrEncode,
typename Problem::DataType>::TransposedDstrEncode;
constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
return block_dstr;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
{
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t kVectorSize = Problem::LDSVectorSize;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kSecondDimPerBlock>{},
number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}),
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
number<kVectorSize>{},
number<1>{});
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor()
{
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t kVectorSize = Problem::LDSVectorSize;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kSecondDimPerBlock>{},
number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}),
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
number<kVectorSize>{},
number<1>{});
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
{
using DataType = typename Problem::DataType;
// Calculate block-level dimensions
constexpr index_t kLeadIterPerWarp = 1;
constexpr index_t kSecondIterPerWarp = 1;
constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
// Calculate repetitions of base pattern
constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim;
constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim;
constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
constexpr index_t kLaneGroupSize = 16;
constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
kLaneGroupSize,
kSecondDimStrSub,
kSecondDimIterations,
kLeadRepetitions,
1>();
constexpr auto input_tile_encode =
InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
kLeadIterPerWarp,
kSecondIterPerWarp,
kLeadNumWarps,
kSecondNumWarps>();
constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
return block_dstr;
}
};
} // namespace ck_tile
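Both descriptors above build the same row-major LDS image: the (lead/vector, vector) split is undone by the merge transform, so addressing is unchanged and only the lane distribution differs. A standalone offset check with a hypothetical 64-wide tile and vector size 8:

constexpr int lds_offset(int row, int col, int lead_dim) { return row * lead_dim + col; }
// column 17 split as (17 / 8, 17 % 8) and merged back is still column 17
static_assert(lds_offset(3, (17 / 8) * 8 + (17 % 8), 64) == 3 * 64 + 17);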

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Supports a 2D transpose that first stores tiles to LDS,
// then uses the ds_read_b*_tr_b* instructions to read the data back transposed.
template <typename DataType_,
typename BlockTile, // sequence<block_x, block_y>
typename NumWarps,
bool kPadM_,
bool kPadN_>
struct BatchedTransposeLdsProblem
{
using DataType = remove_cvref_t<DataType_>;
static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{});
static constexpr index_t kColWarps_ = NumWarps::at(number<1>{});
static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_;
static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{});
static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{});
static constexpr index_t kBlockSize = kBlockSize_;
// warps per block
static constexpr index_t kLeadNumWarps = kColWarps_;
static constexpr index_t kSecondNumWarps = kRowWarps_;
static constexpr index_t kLeadSizePerBlock = kColPerBlock_;
static constexpr index_t kSecondSizePerBlock = kRowPerBlock_;
static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
"block dim must be divisible by the warp count!");
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
"block dim must be divisible by the warp count!");
// rows/cols per warp
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0,
"xdl dim must be divisible by the quad dim!");
static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0,
"xdl dim must be divisible by the quad dim!");
// xdl rows/cols are divided into quadrants.
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerWarp / kQuadrantLeadDim;
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerWarp / kQuadrantSecondDim;
static constexpr index_t kIterationsInSecondDim =
kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
// definitions to adapt to BatchedTransposeKernel
// FIXME: support padding
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr auto kMPerBlock = kSecondSizePerBlock;
static constexpr auto kNPerBlock = kLeadSizePerBlock;
// 128 bits is the maximum width of a single load/store instruction
static constexpr index_t MaxLoadStoreSize = 16;
static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType);
static constexpr auto VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType);
static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType);
};
} // namespace ck_tile
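Worked quadrant numbers (illustrative only; the quadrant shape comes from LaneGroupTransposeTraits and differs per DataType). Assuming a 64-lane warp, a 64x16 per-warp tile, and a 16x4 quadrant:

constexpr int kQuadPerLead   = 64 / 16; // kQuadNumPerLeadDim
constexpr int kQuadPerSecond = 16 / 4;  // kQuadNumPerSecondDim
static_assert(kQuadPerLead * kQuadPerSecond * 16 / 64 == 4); // kIterationsInSecondDim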

View File

@@ -5,8 +5,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
@@ -14,39 +12,26 @@ template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
struct BatchedTransposePipeline
{
// TODO: this kernel only supports warp-per-row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using InputType = ck_tile::remove_cvref_t<typename Problem::InputType>;
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t AlignmentM = Problem::AlignmentM;
static constexpr index_t AlignmentN = Problem::AlignmentN;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
template <typename InputWindow, typename OutputWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
{
auto inp_win =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto input_tile = load_tile(inp_win);
auto output_tile = make_static_distributed_tensor<typename Problem::DataType>(
Policy::template MakeOutputDistribution<Problem>());
transpose_tile2d(output_tile, input_tile);
auto out_win =
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
auto x = load_tile(inp_win); // x->thread input_win->block
auto y = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx1, idx0);
y(i_j_idx) = x(i_j_idx);
});
});
store_tile(out_win, y);
store_tile(out_win, output_tile);
}
};
} // namespace ck_tile

View File

@@ -4,41 +4,26 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
#include "batched_transpose_common_policy.hpp"
namespace ck_tile {
struct BatchedTransposePolicy
struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::kMPerBlock;
constexpr index_t NPerBlock = Problem::kNPerBlock;
constexpr index_t VecLoadSize = Problem::VectorSizeOutput;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2, 1>,
sequence<2, 2>>{});
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
MPerBlock,
NPerBlock,
VecLoadSize,
TileAccessPattern>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
}
};
} // namespace ck_tile

View File

@@ -4,45 +4,33 @@
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
#define VectorLoadSize 16
namespace ck_tile {
template <typename InputType_,
typename BlockTile, // Sequence<...
typename WarpTile, // Sequence<...
typename ThreadTile, // Sequence<...
bool kPadM_ = true,
bool kPadN_ = true>
template <typename DataType_,
typename BlockTile, // Sequence<...
typename WarpLayout,
bool kPadM_ = false,
bool kPadN_ = false> // Sequence<...
struct BatchedTransposeProblem
{
using InputType = remove_cvref_t<InputType_>;
using DataType = remove_cvref_t<DataType_>;
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
static constexpr index_t kMPerWarp = WarpLayout::at(number<0>{});
static constexpr index_t kNPerWarp = WarpLayout::at(number<1>{});
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
static constexpr index_t kBlockSize =
kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock;
static constexpr index_t kBlockSize = kMPerWarp * kNPerWarp * get_warp_size();
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
// 128 bits is the maximum width of a single load/store instruction
static constexpr index_t MaxLoadStoreSize = 16;
static constexpr index_t VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType);
static constexpr index_t VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType);
};
} // namespace ck_tile
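The 16-byte vector arithmetic above, worked for common element widths (std::uint16_t and std::uint32_t stand in for fp16 and fp32 here):

#include <cstdint>
static_assert(16 / sizeof(std::uint16_t) == 8); // fp16-sized: 8-wide vectors
static_assert(16 / sizeof(std::uint32_t) == 4); // fp32-sized: 4-wide vectors
// a padded dim falls back to scalar (vector size 1) access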

include/ck_tile/ops/common/utils.hpp Executable file → Normal file
View File

View File

@@ -3,6 +3,11 @@
#pragma once
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,94 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
namespace element_wise {
struct Add
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const float& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 + type_convert<float>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
{
y = type_convert<half_t>(x0 + x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) + x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bf16_t& x1) const
{
const float x1_tmp = type_convert<float>(x1);
y = x0 + x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bf16_t>(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const
{
const float x1_tmp = type_convert<float>(x0);
const float x2_tmp = type_convert<float>(x1);
const float y_tmp = x1_tmp + x2_tmp;
y = type_convert<bf16_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bf16_t>(bf16_t& y, const float& x0, const bf16_t& x1) const
{
const float x2_tmp = type_convert<float>(x1);
const float y_tmp = x0 + x2_tmp;
y = type_convert<bf16_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = x0 + x1;
};
};
} // namespace element_wise
} // namespace ck_tile
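A host-side usage sketch for the functor above (hypothetical values); the y/x0/x1 types select the specialization at compile time:

inline float add_example()
{
    float y = 0.f;
    ck_tile::element_wise::Add{}(y, 1.5f, 2.5f); // float/float/float path: y = 4.0f
    return y;
}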

View File

@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_>
struct ElementWiseKernel
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(Dims lens,
Dims input_strides,
Dims output_strides,
const tuple<XDataType...>& input_tensors,
YDataType* p_y) const
{
using S = typename Problem::BlockShape;
// Setup block-level coordinates and transforms
const index_t iM = get_block_id() * S::kBlockM;
const auto merge_transform = make_merge_transform(lens);
// Load all input tiles into registers.
// The lambda structure here is intended to minimize the lifetime
// of intermediate objects (views, windows) used for loading.
const auto x_tiles = ck_tile::generate_tuple(
[&](auto i) {
const auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
const auto transformed_tensor = pad_tensor_view(
transform_tensor_view(tensor_view,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
const auto x_window =
make_tile_window(transformed_tensor,
ck_tile::make_tuple(number<S::kBlockM>{}),
{iM},
Policy::template MakeXBlockTileDistribution<Problem>());
return load_tile(x_window);
},
number<sizeof...(XDataType)>{});
// Setup output tile in registers.
const auto& x_tile0 = x_tiles.get(number<0>{});
auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
// Perform element-wise computation.
const auto spans = x_tile0.get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx) {
const auto tile_idx = make_tuple(idx);
apply(
[&](auto&&... tiles) {
ElementWiseOperation{}(y_tile(tile_idx),
type_convert<ComputeDataType>(tiles[tile_idx])...);
},
x_tiles);
});
// Setup output window and store the result tile.
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, lens, output_strides, number<S::kVectorM>{});
const auto transformed_y_m_n = pad_tensor_view(
transform_tensor_view(y_m_n,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
auto y_window = make_tile_window(transformed_y_m_n,
make_tuple(number<S::kBlockM>{}),
{iM},
y_tile.get_tile_distribution());
store_tile(y_window, cast_tile<YDataType>(y_tile));
}
template <typename... Ints>
CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes)
{
int total_elements = 1;
const auto kVectorM = Problem_::BlockShape::kVectorM;
apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
if((total_elements % kVectorM) != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met: total number of input elements (",
total_elements,
") should be multiple of the vectorization size (",
kVectorM,
")");
}
return false;
}
return true;
}
};
} // namespace ck_tile
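Index sketch for the merge transform used above (standalone, illustrative): an N-D tensor is viewed through one merged dimension, so the kernel only ever tiles a 1-D extent.

constexpr int flatten2d(int i, int j, int stride_i, int stride_j)
{
    return i * stride_i + j * stride_j;
}
static_assert(flatten2d(2, 3, 8, 1) == 19); // (2,3) in a row-major 8-wide view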

View File

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct ElementWiseDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // Replicate
tuple<sequence<S::kRepeatM,
S::kWarpPerBlockM,
S::kThreadPerWarpM,
S::kVectorM>>, // Hierarchical
tuple<sequence<1>, sequence<1>>, // P majors: both P dims map into the H sequence
tuple<sequence<1>, sequence<2>>, // P minors: kWarpPerBlockM, kThreadPerWarpM
sequence<1, 1>, // Y majors: both Y dims map into the H sequence
sequence<0, 3>>{} // Y minors: kRepeatM, kVectorM
);
}
};
} // namespace ck_tile
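Decomposition sketch for the encoding above (hypothetical shape): the single M extent splits as RepeatM x WarpPerBlockM x ThreadPerWarpM x VectorM; the warp and lane factors are the parallel dims, repeat and vector stay per-thread.

static_assert(2 * 4 * 64 * 4 == 2048); // e.g. kBlockM = 2048 on a 64-lane wave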

View File

@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename XDataType_,
typename ComputeDataType_,
typename YDataType_,
typename BlockShape_,
typename ElementWiseOperation_,
bool kPad_ = true>
struct ElementWisePipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
static constexpr bool kPad = kPad_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename ComputeDataType>
struct ElementWiseShape
{
static constexpr index_t kBlockM = BlockTile::at(number<0>{});
static constexpr index_t kWarpM = WarpTile::at(number<0>{});
static constexpr index_t kVectorM =
min(static_cast<index_t>(16 / sizeof(ComputeDataType)), kWarpM / get_warp_size());
static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{});
static constexpr index_t kThreadPerWarpM = get_warp_size();
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);
static constexpr index_t kBlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
} // namespace ck_tile
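Worked shape numbers (illustrative): ComputeDataType = float, BlockWarps = sequence<4>, BlockTile = sequence<1024>, WarpTile = sequence<256> on a 64-lane wave:

constexpr int kVectorM_ex = (16 / 4 < 256 / 64) ? 16 / 4 : 256 / 64; // min(4, 4)
static_assert(kVectorM_ex == 4);
static_assert(1024 / (4 * kVectorM_ex * 64) == 1); // kRepeatM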

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -110,6 +110,86 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
return res;
}
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
{
uint32_t src = static_cast<uint32_t>(a), src_hi;
uint32_t fp8x4_lo, fp8x4_hi;
float tmp_0, tmp_1;
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
: [v_tmp_0] "+v"(tmp_0),
[v_tmp_1] "+v"(tmp_1),
[v_hi_src] "+v"(src_hi),
[v_dst_lo] "+v"(fp8x4_lo),
[v_dst_hi] "+v"(fp8x4_hi),
[v_src] "+v"(src)
:);
return bit_cast<fp8x8_t>(((static_cast<uint64_t>(fp8x4_hi) << 32) | fp8x4_lo));
}
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
float res;
asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
return res;
}
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
{
float res;
asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
return res;
}
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a)
{
uint32_t src = static_cast<uint32_t>(a), src_hi;
uint32_t bf8x4_lo, bf8x4_hi;
float tmp_0, tmp_1;
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
: [v_tmp_0] "+v"(tmp_0),
[v_tmp_1] "+v"(tmp_1),
[v_hi_src] "+v"(src_hi),
[v_dst_lo] "+v"(bf8x4_lo),
[v_dst_hi] "+v"(bf8x4_hi),
[v_src] "+v"(src)
:);
return bit_cast<bf8x8_t>(((static_cast<uint64_t>(bf8x4_hi) << 32) | bf8x4_lo));
}
struct PassThroughPack8
{
template <typename Y, typename X>
@@ -126,6 +206,16 @@ struct PassThroughPack8
y.lo = i4_to_bhalf4(bit_cast<int>(x));
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
}
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
y = amd_assembly_i4_to_fp8x8(bit_cast<int>(x));
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
y = amd_assembly_i4_to_bf8x8(bit_cast<int>(x));
}
constexpr const static bool is_pack8_invocable = true;
};
@@ -172,219 +262,67 @@ struct PassThroughPack2
struct PassThrough
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <class T>
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
template <class Y, class X>
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
{
y = x;
/* Only do the assignment when
- y is an *l-value* and
- y is *not* const */
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
{
y = ck_tile::type_convert<raw_t<Y>>(x);
}
/* otherwise (r-value or const) → do nothing */
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = type_convert<float>(x);
}
// Suppress unused parameter warning for ds
((void)ds, ...);
template <>
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
// Just assign e with c
if constexpr(std::is_same_v<E, C>)
{
e = c;
}
else
{
e = ck_tile::type_convert<E>(c);
}
}
};
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
struct MultiDMultiply
{
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = x;
}
// Start with the base value c
float result = ck_tile::type_convert<float>(c);
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = x;
}
// Multiply by each D parameter using fold expression
((result *= ck_tile::type_convert<float>(ds)), ...);
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
e = ck_tile::type_convert<E>(result);
}
};
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
struct MultiDAdd
{
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = x;
}
// Start with the base value c
float result = ck_tile::type_convert<float>(c);
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
// Add by each D parameter using fold expression
((result += ck_tile::type_convert<float>(ds)), ...);
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
const ck_tile::bf16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
const ck_tile::fp16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
const ck_tile::fp8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
const ck_tile::bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
e = ck_tile::type_convert<E>(result);
}
};
@@ -1479,5 +1417,6 @@ struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
} // namespace element_wise
} // namespace ck_tile
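Usage sketch for the MultiD functors above (hypothetical values): every D operand is accumulated in fp32 by the fold expression before the single final convert.

inline float multi_d_example()
{
    float e = 0.f;
    ck_tile::element_wise::MultiDAdd{}(e, 1.0f, 0.5f, 0.25f);     // e = 1.75f
    ck_tile::element_wise::MultiDMultiply{}(e, 2.0f, 3.0f, 0.5f); // e = 3.0f
    return e;
}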

View File

@@ -4,9 +4,9 @@
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -11,34 +11,52 @@ namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename CLayout_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
index_t kMWave_,
index_t kNWave_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_>
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t kMWave = kMWave_;
static constexpr index_t kNWave = kNWave_;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
@@ -49,32 +67,33 @@ struct CShuffleEpilogue
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ODataType, BDataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t kMWave = Problem::kMWave;
static constexpr index_t kNWave = Problem::kNWave;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
/**
* @brief Get the vector store size for C tensor.
*
@@ -85,100 +104,244 @@ struct CShuffleEpilogue
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
return MaxVectorStoreSize / sizeof(ODataType);
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters
*
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
* iteration:
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
}
else
{
static_assert(false, "Unsupported CLayout!");
static_assert(false, "Unsupported ELayout!");
}
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
return block_dstr_encoding;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
template <typename ODramWindow,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
const index_t iMWarp = get_warp_id() / kNWave;
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
auto out_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
{0, 0});
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked>;
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
CWarpTensor c_warp_in_tensor;
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
block_sync_lds();
store_tile(in_lds_window, c_warp_in_tensor_casted);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
if constexpr(out_memory_data_op == memory_operation_enum::set)
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
@@ -189,7 +352,13 @@ struct CShuffleEpilogue
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
});
}
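Worked store-vector example for GetVectorSizeC above (illustrative): with a RowMajor output, a 2-byte ODataType, and NPerIteration = 128, the store width is min(128, 16/2) = 8 elements; a 4-byte ODataType drops it to 4.

static_assert((16 / 2 < 128 ? 16 / 2 : 128) == 8);
static_assert((16 / 4 < 128 ? 16 / 4 : 128) == 4);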

View File

@@ -15,17 +15,21 @@ template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool UseRawStore_ = true>
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
struct Default2DEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
};
template <typename AccDataType_,
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename ODataType_,
typename CLayout_,
bool kPadM_,
@@ -34,10 +38,17 @@ template <typename AccDataType_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataType_,
ODataType_,
kPadM_,
kPadN_,
UseRawStore_,
MemoryOperation_>
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
@@ -54,22 +65,20 @@ struct Default2DEpilogue
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
// TODO: this function assumes the store-out vector size matches the OAccTile last-dimension size;
// how do we fix this?
template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
{
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
@@ -81,7 +90,7 @@ struct Default2DEpilogue
}
else
{
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
@@ -91,27 +100,43 @@ struct Default2DEpilogue
}
}
}
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& /* unused */,
void* = nullptr) const
{
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
}
};
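
Both overloads now funnel into the same store path, so callers that pass a (currently unused) tuple of D windows compile unchanged. A hedged call sketch, with hypothetical window and tile names:

Default2DEpilogue<Problem>{}(o_dram_window, o_acc_tile);                     // original form
Default2DEpilogue<Problem>{}(o_dram_window, o_acc_tile, ck_tile::tuple<>{}); // Ds ignored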
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
// Used by the weight-only quantization kernel: B is dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = ck_tile::tuple<>;
using DsLayout = ck_tile::tuple<>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
@@ -134,7 +159,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
else
{
// In this case each thread has just a single item in Ndim
return (WG::WarpGemmAttribute::Impl::kCNLane *
WG::WarpGemmAttribute::Impl::kBNBlock) /
WG::kN;
}
}
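// Worked example with hypothetical attribute values: for kCNLane = 16,
// kBNBlock = 2 and WG::kN = 16 this yields (16 * 2) / 16 = 2 iterations, while
// kBNBlock = 1 reduces it to the previous kCNLane / kN result of 1.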
// M is contiguous dimension
@@ -143,7 +170,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
if constexpr(isCTransposed)
{
// In this case each thread has just a single item in Mdim
return (WG::WarpGemmAttribute::Impl::kCNLane *
WG::WarpGemmAttribute::Impl::kAMBlock) /
WG::kN;
}
else
{
@@ -162,6 +191,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
static_assert(false, "Unsupported CLayout!");
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
};
} // namespace ck_tile

View File

@@ -3,10 +3,16 @@
#pragma once
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -0,0 +1,122 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
namespace ck_tile {
// A is a block window on shared memory
// B is a block window on shared memory
// C is a block-distributed tensor
template <typename Problem_, typename BlockPolicy_>
struct BlockFlatmmASmemBSmemCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto idxM = I0;
static constexpr auto idxN = I1;
static constexpr auto idxK = I2;
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
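// Numeric sketch, assuming the 32x512x128 block / 1x4x1 warps / 16x16x32
// warp-tile configuration suggested by the flatmm headers above (hypothetical
// for this class): MIterPerWarp = 32 / (1 * 16) = 2, NIterPerWarp = 512 / (4 * 16) = 8.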
// C += A * B
template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
ABlockWindow& a_warp_windows,
BFlatBlockTensor& b_warp_tensor) const
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp =
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
} // namespace ck_tile
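
A hedged usage sketch of the block-level flatmm; the window and tensor arguments are placeholders for objects built by the surrounding pipeline:

using BlockFlatmm = BlockFlatmmASmemBSmemCRegV1<Problem, BlockPolicy>;
auto c_block_tensor = BlockFlatmm::MakeCBlockTile();
clear_tile(c_block_tensor);                                   // zero-initialize the accumulator
BlockFlatmm{}(c_block_tensor, a_warp_windows, b_warp_tensor); // C += A * B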

Some files were not shown because too many files have changed in this diff.