mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Merge branch 'develop' into hstu_attention_n0loop_fused_unroll
This commit is contained in:
@@ -10,9 +10,11 @@
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/arch/workgroup_barrier.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
@@ -25,19 +27,23 @@
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
#include "ck_tile/core/numeric/null_type.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/shuffle_tile.hpp"
|
||||
@@ -53,12 +59,15 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/debug.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
@@ -66,6 +75,7 @@
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
|
||||
{
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("pass_through{");
|
||||
|
||||
//
|
||||
printf("up_lengths_:");
|
||||
print(up_lengths_);
|
||||
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
|
||||
{
|
||||
printf("pass_through{");
|
||||
|
||||
printf("up_lengths_: ");
|
||||
print(pt.get_upper_lengths());
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength,
|
||||
typename LeftPadLength,
|
||||
typename RightPadLength,
|
||||
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
|
||||
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
|
||||
ck_tile::is_known_at_compile_time<RightPadLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("left_pad_length_: ");
|
||||
print(left_pad_length_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_pad_length_: ");
|
||||
print(right_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength,
|
||||
typename LeftPadLength,
|
||||
typename RightPadLength,
|
||||
bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
|
||||
{
|
||||
printf("pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(p.up_lengths_);
|
||||
printf(", left_pad_length_: ");
|
||||
print(p.left_pad_length_);
|
||||
printf(", right_pad_length_: ");
|
||||
print(p.right_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
|
||||
struct left_pad
|
||||
{
|
||||
@@ -330,24 +326,20 @@ struct left_pad
|
||||
// It's up to runtime to check the padding length should be multiple of vector length
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("left_pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("left_pad_length_: ");
|
||||
print(left_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
|
||||
{
|
||||
printf("left_pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(lp.up_lengths_);
|
||||
printf(", left_pad_length_: ");
|
||||
print(lp.left_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||
struct right_pad : public base_transform<1, 1>
|
||||
{
|
||||
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
|
||||
// It's up to runtime to check the padding length should be multiple of vector length
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("right_pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_pad_length_: ");
|
||||
print(right_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
|
||||
{
|
||||
printf("right_pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(rp.up_lengths_);
|
||||
printf(", right_pad_length_: ");
|
||||
print(rp.right_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
|
||||
// UpLengths and Coefficients can be either of the followings:
|
||||
// 1) Tuple of index_t, which is known at run-time, or
|
||||
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<Coefficients>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("coefficients_: ");
|
||||
print(coefficients_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, typename Coefficients>
|
||||
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
|
||||
{
|
||||
printf("embed{");
|
||||
printf("up_lengths_: ");
|
||||
print(e.up_lengths_);
|
||||
printf(", coefficients_: ");
|
||||
print(e.coefficients_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
|
||||
{
|
||||
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("merge_v2_magic_division{");
|
||||
|
||||
//
|
||||
printf("low_lengths_ ");
|
||||
print(low_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_ ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
|
||||
{
|
||||
printf("merge_v2_magic_division{");
|
||||
printf("low_lengths_: ");
|
||||
print(m.low_lengths_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
|
||||
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
|
||||
// will be very bad
|
||||
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("Merge_v3_direct_division_mod{");
|
||||
|
||||
//
|
||||
printf("low_lengths_ ");
|
||||
print(low_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("low_lengths_scan_ ");
|
||||
print(low_lengths_scan_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_ ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
|
||||
{
|
||||
printf("merge_v3_division_mod{");
|
||||
printf("low_lengths_: ");
|
||||
print(m.low_lengths_);
|
||||
printf(", low_lengths_scan_: ");
|
||||
print(m.low_lengths_scan_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
{
|
||||
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("unmerge{");
|
||||
|
||||
//
|
||||
printf("up_lengths_");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_scan_");
|
||||
print(up_lengths_scan_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
|
||||
{
|
||||
printf("unmerge{");
|
||||
printf("up_lengths_: ");
|
||||
print(u.up_lengths_);
|
||||
printf(", up_lengths_scan_: ");
|
||||
print(u.up_lengths_scan_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowerIndex>
|
||||
struct freeze : public base_transform<1, 0>
|
||||
{
|
||||
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("freeze{");
|
||||
|
||||
//
|
||||
printf("low_idx_: ");
|
||||
print(low_idx_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowerIndex>
|
||||
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
|
||||
{
|
||||
printf("freeze{");
|
||||
printf("low_idx_: ");
|
||||
print(f.low_idx_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// insert a dangling upper dimension without lower dimension
|
||||
template <typename UpperLength>
|
||||
struct insert : public base_transform<0, 1>
|
||||
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpperLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("insert{");
|
||||
|
||||
//
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpperLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
|
||||
{
|
||||
printf("insert{");
|
||||
printf("up_lengths_: ");
|
||||
print(i.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// replicate the original tensor and create a higher dimensional tensor
|
||||
template <typename UpLengths>
|
||||
struct replicate : public base_transform<0, UpLengths::size()>
|
||||
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("replicate{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
//
|
||||
UpLengths up_lengths_;
|
||||
};
|
||||
|
||||
template <typename UpLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
|
||||
{
|
||||
printf("replicate{");
|
||||
printf("up_lengths_: ");
|
||||
print(r.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||
struct slice : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
|
||||
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
|
||||
ck_tile::is_known_at_compile_time<SliceEnd>::value;
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("slice{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("slice_begin_: ");
|
||||
print(slice_begin_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("slice_end_: ");
|
||||
print(slice_end_);
|
||||
|
||||
printf("}");
|
||||
} // namespace ck
|
||||
}; // namespace ck
|
||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
|
||||
{
|
||||
printf("slice{");
|
||||
printf("up_lengths_: ");
|
||||
print(s.up_lengths_);
|
||||
printf(", slice_begin_: ");
|
||||
print(s.slice_begin_);
|
||||
printf(", slice_end_: ");
|
||||
print(s.slice_end_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief lower_idx = upper_idx % modulus.
|
||||
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("Modulus{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Modulus, typename UpLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
|
||||
{
|
||||
printf("modulo{");
|
||||
printf("modulus_: ");
|
||||
print(m.modulus_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// 2D XOR, NOTE: "xor" is a keyword
|
||||
template <typename LowLengths>
|
||||
struct xor_t : public base_transform<2, 2>
|
||||
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("xor_t{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
|
||||
{
|
||||
printf("xor_t{");
|
||||
printf("up_lengths_: ");
|
||||
print(x.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
struct offset : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<OffsetLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("offset{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("offset_length_: ");
|
||||
print(offset_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
|
||||
{
|
||||
printf("offset{");
|
||||
printf("up_lengths_: ");
|
||||
print(o.up_lengths_);
|
||||
printf(", offset_length_: ");
|
||||
print(o.offset_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
struct indexing : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
IndexingAdaptor::is_known_at_compile_time();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
|
||||
{
|
||||
printf("indexing{");
|
||||
printf("up_lengths_: ");
|
||||
print(i.up_lengths_);
|
||||
printf(", iadaptor_: ");
|
||||
print(i.iadaptor_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
|
||||
@@ -100,10 +100,8 @@ struct space_filling_curve
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
// idim-th element of multidimensional index.
|
||||
// All constexpr variables have to be captured by VALUE.
|
||||
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
|
||||
{
|
||||
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
|
||||
{
|
||||
constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
|
||||
constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
|
||||
auto res = idx_1d.value;
|
||||
auto id = 0;
|
||||
|
||||
|
||||
@@ -1,6 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/**
|
||||
* @file
|
||||
* We're defining the data access pattern for a 2D window (`XPerTile` by `YPerTile`)
|
||||
for `BlockSize` threads in a thread block.
|
||||
* X dimension is considered contiguous in memory, so a single instruction can access
|
||||
several adjacent and properly aligned elements (vector); the access pattern along X tile
|
||||
dimension is parameterized only by the suggested vector size `VecSize`.
|
||||
* We can't access more than `MaxVecSize = TileElementsPerThread = TileSize / BlockSize` elements
|
||||
with a single memory access, so the actual vector size along the X dimension is
|
||||
`X0 = min(MaxVecSize, VecSize)`.
|
||||
* This leaves `X1 = XPerTile / X0` threads per tile in X dimension.
|
||||
* X1 is also the number of threads per warp in X dimension, that is,
|
||||
X dimension is not split between warps, and each warp accesses X dimension entirely,
|
||||
and there is no iteration in X dimension.
|
||||
* The tuple <X0, X1> defines the X-axis access pattern.
|
||||
This part is common between the 2D distribution patterns.
|
||||
|
||||
* What's different between the different 2D distribution patterns, is the Y axis access pattern.
|
||||
* There are 3 components in this access pattern;
|
||||
* (1) number of Y-axis elements (rows) per warp for a single instruction access,
|
||||
* (2) number of warps per thread block,
|
||||
* (3) number of iterations to cover the entire Y axis.
|
||||
|
||||
* The raked here represents how data is partitioned across different processing granularity.
|
||||
* It represents howe we are going to access the data in thread, warp, or blocked in contiguous
|
||||
region.
|
||||
* From below, the qualifier for 'raked' is the part of warp/thread hierarchy
|
||||
* in the split of Y tile dimension where the iteration happens,
|
||||
* meaning, the iteration can be logically inserted as a tile dimension in 3 ways,
|
||||
* (1) after thread -> thread-raked,
|
||||
* (2) between warp and thread -> warp-raked,
|
||||
* (3) before warp -> block-raked
|
||||
|
||||
* *Thread raked*
|
||||
|
||||
* Y0 is the number of warps, which we can get from the equation `Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of rows accessed by a warp within a single iteration,
|
||||
compute it from the equation `Y0 * X1 == WarpSize`
|
||||
* Y2 is the number of iterations to cover the tile,
|
||||
compute it from the equation `Y0 * Y1 * Y2 == YPerTile`
|
||||
|
||||
* *Warp raked*
|
||||
|
||||
* Y0 is the number of warps, we can get it in the same way as for thread-raked pattern,
|
||||
`Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y2 from the equation below
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* *Block raked*
|
||||
|
||||
* Y0 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y1 and Y2 from the equations below
|
||||
* Y1 is the number of warps, `Y1 * WarpSize == BlockSize`
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* In all cases, the tuple <Y0, Y1, Y2> defines the Y-axis access pattern.
|
||||
|
||||
* *Selection*
|
||||
* When we are selecting, Thread-raked is used in element-wise operation because it is the
|
||||
* Thread-major memory order.
|
||||
* Warp-raked is used in matrix multiplication because the vectorization is in warp level.
|
||||
* Block-raked is used mostly for the reduction process, where will reduce the block in global
|
||||
* atomic level.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
@@ -10,6 +77,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -56,78 +124,116 @@ template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern>
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups = 1>
|
||||
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
|
||||
{
|
||||
};
|
||||
|
||||
// Thread raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t X1 = VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
|
||||
// # of rows in Y dim accessed by single wavefront in one iteration
|
||||
static constexpr index_t Y1 = warp_size / X0;
|
||||
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
|
||||
|
||||
static constexpr index_t Y0 = num_warps;
|
||||
static constexpr index_t Y0 = num_warps / NumWaveGroups;
|
||||
// YPerWarp = YPerTile / Y0;
|
||||
// Y2 = YPerWarp / Y1;
|
||||
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
|
||||
|
||||
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
|
||||
"X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{}); // -> <X1, Y2>
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{}); // -> <X1, Y2>
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Warp raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t X1 = VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
@@ -144,9 +250,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <Y1, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
@@ -155,28 +261,33 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <X1, Y1>
|
||||
}
|
||||
};
|
||||
|
||||
// Block raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t X1 = VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
static constexpr index_t Y1 = num_warps;
|
||||
@@ -190,9 +301,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
sequence<0, 1>>{}); // -> <Y0, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
@@ -201,10 +312,57 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
sequence<1, 0>>{}); // -> <X1, Y0>
|
||||
}
|
||||
};
|
||||
|
||||
// Helper function to convert enum to string
|
||||
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
|
||||
{
|
||||
switch(pattern)
|
||||
{
|
||||
case tile_distribution_pattern::thread_raked: return "thread_raked";
|
||||
case tile_distribution_pattern::warp_raked: return "warp_raked";
|
||||
case tile_distribution_pattern::block_raked: return "block_raked";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups>
|
||||
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
{
|
||||
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
|
||||
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
"VecSize:%d, %s>: ",
|
||||
BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern_to_string(DistributionPattern));
|
||||
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
|
||||
PatternType::Y0,
|
||||
PatternType::Y1,
|
||||
PatternType::Y2,
|
||||
PatternType::X0,
|
||||
PatternType::X1);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -13,6 +13,18 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
// This attribute gives a hint to the compiler that a branch is likely to be taken.
|
||||
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
|
||||
// have been generated.
|
||||
#if __cplusplus >= 202002L
|
||||
#define LIKELY(x) (x) [[likely]]
|
||||
#else
|
||||
#define LIKELY(x) (__builtin_expect(!!(x), 1))
|
||||
#endif
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -54,10 +66,36 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
|
||||
// TODO: glc/slc/...
|
||||
template <index_t bytes, bool pre_nop = false>
|
||||
struct buffer_load;
|
||||
|
||||
template <index_t bytes, bool pre_nop = false>
|
||||
struct buffer_load_if;
|
||||
|
||||
template <index_t bytes>
|
||||
struct buffer_store;
|
||||
|
||||
template <index_t bytes>
|
||||
struct buffer_store_if;
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
|
||||
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
|
||||
// (exp_vector_type(xxx))
|
||||
|
||||
#define HAS_RAW_BUFFER_BUILTINS \
|
||||
__has_builtin(__builtin_amdgcn_raw_buffer_load_b32) && \
|
||||
__has_builtin(__builtin_amdgcn_make_buffer_rsrc) && \
|
||||
__has_builtin(__builtin_amdgcn_raw_buffer_store_b32)
|
||||
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
CK_TILE_DEVICE __amdgpu_buffer_rsrc_t cast_to_amdgpu_buffer_rsrc_t(int32x4_t res)
|
||||
{
|
||||
__amdgpu_buffer_rsrc_t as_rsrc;
|
||||
static_assert(sizeof(res) == sizeof(as_rsrc) && "Size of buffer resource should match");
|
||||
memcpy(&as_rsrc, &res, sizeof(res));
|
||||
return as_rsrc;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_load<16, pre_nop>
|
||||
{
|
||||
@@ -72,6 +110,11 @@ struct buffer_load<16, pre_nop>
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b128(
|
||||
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
|
||||
@@ -83,6 +126,7 @@ struct buffer_load<16, pre_nop>
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -100,6 +144,11 @@ struct buffer_load<8, pre_nop>
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b64(
|
||||
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
|
||||
@@ -111,6 +160,7 @@ struct buffer_load<8, pre_nop>
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -128,6 +178,12 @@ struct buffer_load<4, pre_nop>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b32(
|
||||
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dword %0, %1, %2, 0 offen offset:%3"
|
||||
@@ -139,6 +195,7 @@ struct buffer_load<4, pre_nop>
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -156,6 +213,12 @@ struct buffer_load<2, pre_nop>
|
||||
{
|
||||
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b16(
|
||||
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
|
||||
@@ -167,6 +230,7 @@ struct buffer_load<2, pre_nop>
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -184,6 +248,11 @@ struct buffer_load<1, pre_nop>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b16(
|
||||
cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
|
||||
@@ -195,12 +264,31 @@ struct buffer_load<1, pre_nop>
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t bytes, bool pre_nop = false>
|
||||
struct buffer_load_if;
|
||||
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
template <index_t bytes, bool pre_nop>
|
||||
struct buffer_load_if
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
if LIKELY(1 <= flag)
|
||||
{
|
||||
buffer_load<bytes, pre_nop>{}(
|
||||
value, res, v_offset, s_offset, i_offset, flag, bool_constant<pre_nop>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
#else
|
||||
template <bool pre_nop>
|
||||
struct buffer_load_if<16, pre_nop>
|
||||
{
|
||||
@@ -366,9 +454,9 @@ struct buffer_load_if<1, pre_nop>
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
|
||||
template <index_t bytes>
|
||||
struct buffer_store;
|
||||
|
||||
template <>
|
||||
struct buffer_store<16>
|
||||
@@ -383,10 +471,16 @@ struct buffer_store<16>
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = fp32x4_t;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
__builtin_amdgcn_raw_buffer_store_b128(
|
||||
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -403,10 +497,16 @@ struct buffer_store<8>
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = fp32x2_t;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
__builtin_amdgcn_raw_buffer_store_b64(
|
||||
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -423,10 +523,16 @@ struct buffer_store<4>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
__builtin_amdgcn_raw_buffer_store_b32(
|
||||
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -443,10 +549,16 @@ struct buffer_store<2>
|
||||
{
|
||||
static_assert(sizeof(T) == 2);
|
||||
using mbuf_t = short;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
__builtin_amdgcn_raw_buffer_store_b16(
|
||||
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -463,16 +575,38 @@ struct buffer_store<1>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
index_t s_offset = i_offset;
|
||||
__builtin_amdgcn_raw_buffer_store_b8(
|
||||
bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
|
||||
#else
|
||||
asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
|
||||
:
|
||||
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#if HAS_RAW_BUFFER_BUILTINS
|
||||
template <index_t bytes>
|
||||
struct buffer_store_if;
|
||||
|
||||
struct buffer_store_if
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t s_offset,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
if LIKELY(1 <= flag)
|
||||
{
|
||||
buffer_store<bytes>{}(value, res, v_offset, s_offset, i_offset);
|
||||
}
|
||||
}
|
||||
};
|
||||
#else
|
||||
template <>
|
||||
struct buffer_store_if<16>
|
||||
{
|
||||
@@ -613,6 +747,7 @@ struct buffer_store_if<1>
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
|
||||
{
|
||||
@@ -1134,7 +1269,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
@@ -1179,6 +1314,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
@@ -1301,8 +1447,10 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
static_assert(
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, bf16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, int32_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
@@ -1425,6 +1573,54 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
thread_buffer<float, 8> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
thread_buffer<float, 16> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<2>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<3>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
|
||||
{
|
||||
@@ -1461,6 +1657,54 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
thread_buffer<float, 8> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
thread_buffer<float, 16> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<2>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<3>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
else // other datatype
|
||||
{
|
||||
@@ -1515,7 +1759,7 @@ template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
@@ -1545,29 +1789,35 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
|
||||
// Remove this assert once other sizes are implemented.
|
||||
assert(src_immediate_addr_offset == 0 &&
|
||||
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
|
||||
ignore = src_immediate_addr_offset;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
src_wave_addr_offset = 0;
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
#endif
|
||||
|
||||
// Set up v_offset:
|
||||
index_t v_offset = src_thread_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2511,44 +2761,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t NumElemsPerThread>
|
||||
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const index_t global_offset,
|
||||
T* lds_base_ptr,
|
||||
const index_t lds_offset,
|
||||
const bool is_valid,
|
||||
const index_t src_element_space_size)
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
// Direct loads require that each thread reads and writes exactly a single DWORD.
|
||||
constexpr auto dword_bytes = 4;
|
||||
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
|
||||
static_assert(bytes_per_thread == dword_bytes);
|
||||
|
||||
const uint32_t* global_ptr =
|
||||
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
|
||||
const int32x4_t src_resource =
|
||||
make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T));
|
||||
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
|
||||
|
||||
#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
T* lds_ptr = lds_base_ptr + lds_offset;
|
||||
auto const lds_ptr_sgpr =
|
||||
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
|
||||
asm volatile("s_mov_b32 m0, %0; \n\t"
|
||||
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
|
||||
"v"(global_offset_bytes),
|
||||
"s"(src_resource)
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -29,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -881,95 +880,95 @@ CK_TILE_DEVICE_EXTERN int8_t
|
||||
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int8x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int8x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
|
||||
|
||||
// buffer load i16
|
||||
CK_TILE_DEVICE_EXTERN int16_t
|
||||
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int16x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int16x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
|
||||
|
||||
// buffer load i32
|
||||
CK_TILE_DEVICE_EXTERN int32_t
|
||||
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int32x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
|
||||
// buffer load fp16
|
||||
CK_TILE_DEVICE_EXTERN _Float16
|
||||
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_load_fp16x2(
|
||||
int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN fp16x2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN fp16x4_t llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN fp16x4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
|
||||
|
||||
// buffer load fp32
|
||||
CK_TILE_DEVICE_EXTERN float
|
||||
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN fp32x2_t llvm_amdgcn_raw_buffer_load_fp32x2(
|
||||
int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN fp32x2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN fp32x4_t llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN fp32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
|
||||
|
||||
// buffer store i8
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
@@ -977,21 +976,21 @@ llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
|
||||
|
||||
// buffer store i16
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
@@ -999,21 +998,21 @@ llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x2(
|
||||
int16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x4(
|
||||
int16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
// buffer store i32
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
@@ -1021,7 +1020,7 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
|
||||
|
||||
// buffer store ui16
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
@@ -1029,35 +1028,35 @@ llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x2(
|
||||
uint16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x4(
|
||||
uint16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(
|
||||
int32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(
|
||||
int32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
// buffer store fp16
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
@@ -1065,21 +1064,21 @@ llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x2(
|
||||
fp16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x4(
|
||||
fp16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
|
||||
|
||||
// buffer store fp32
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
@@ -1087,21 +1086,21 @@ llvm_amdgcn_raw_buffer_store_fp32(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x2(
|
||||
fp32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x4(
|
||||
fp32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||
|
||||
// buffer atomic-add fp16
|
||||
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
@@ -1109,7 +1108,7 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
|
||||
|
||||
// buffer atomic-add i32
|
||||
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
@@ -1117,7 +1116,7 @@ CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
|
||||
|
||||
// buffer atomic-add fp32
|
||||
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
@@ -1125,25 +1124,25 @@ CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
|
||||
|
||||
// buffer atomic-max fp64
|
||||
CK_TILE_DEVICE_EXTERN double llvm_amdgcn_raw_buffer_atomic_max_fp64(
|
||||
double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32");
|
||||
CK_TILE_DEVICE_EXTERN double
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32");
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
template <bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
|
||||
@@ -1183,6 +1182,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
@@ -1549,29 +1559,35 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
|
||||
// Remove this assert once other sizes are implemented.
|
||||
assert(src_immediate_addr_offset == 0 &&
|
||||
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
|
||||
ignore = src_immediate_addr_offset;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
src_wave_addr_offset = 0;
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
#endif
|
||||
|
||||
// Set up v_offset:
|
||||
index_t v_offset = src_thread_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2523,11 +2539,6 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const bool is_valid,
|
||||
const index_t src_element_space_size)
|
||||
{
|
||||
// Direct loads require that each thread reads and writes exactly a single DWORD.
|
||||
constexpr auto dword_bytes = 4;
|
||||
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
|
||||
static_assert(bytes_per_thread == dword_bytes);
|
||||
|
||||
const uint32_t* global_ptr =
|
||||
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
|
||||
const int32x4_t src_resource =
|
||||
@@ -2544,16 +2555,70 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
"s"(src_resource)
|
||||
: "memory");
|
||||
#else
|
||||
// Direct loads require that each thread reads and writes exactly a single DWORD.
|
||||
#if defined(__gfx9__)
|
||||
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
|
||||
#endif
|
||||
// Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes).
|
||||
// For gfx950: supports 1, 3, or 4 DWORDs per thread
|
||||
// For gfx942: supports exactly 1 DWORD per thread
|
||||
#if defined(__gfx950__)
|
||||
constexpr auto dword_bytes = 4;
|
||||
static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
|
||||
bytes_per_thread == dword_bytes * 4);
|
||||
#elif defined(__gfx9__)
|
||||
constexpr auto dword_bytes = 4;
|
||||
static_assert(bytes_per_thread == dword_bytes);
|
||||
#endif
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
as3_uint32_ptr lds_ptr =
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
88
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
88
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this generate wave level tile distribution
|
||||
template <typename T, index_t LaneGroupSize = 16, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
|
||||
"LaneGroupSize must be 16, 32, or 64");
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief This function is used to generate the transposed distribution encoding
|
||||
* for the given data type and distribution dimensions.
|
||||
*
|
||||
* @tparam T The data type of the elements in the tensor.
|
||||
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
|
||||
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
|
||||
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
|
||||
* consecutive.
|
||||
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t LaneGroupSize,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -9,6 +9,16 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
|
||||
#define CK_TILE_VMCNT(cnt) \
|
||||
([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
|
||||
((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))
|
||||
#define CK_TILE_EXPCNT(cnt) \
|
||||
([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))
|
||||
#define CK_TILE_LGKMCNT(cnt) \
|
||||
([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8))
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -50,8 +60,11 @@ enum struct memory_operation_enum : std::uint16_t
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
// warpSize is defined by HIP
|
||||
return warpSize;
|
||||
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
|
||||
@@ -81,21 +94,6 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
@@ -114,13 +112,68 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
#endif
|
||||
}
|
||||
|
||||
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
|
||||
struct waitcnt_arg
|
||||
{
|
||||
// bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
|
||||
// [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
|
||||
CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111;
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
|
||||
return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10));
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_expcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]");
|
||||
return MAX & (cnt << 4);
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
|
||||
return MAX & (cnt << 8);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
index_t expcnt = waitcnt_arg::kMaxExpCnt,
|
||||
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
|
||||
CK_TILE_DEVICE void s_waitcnt()
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
|
||||
waitcnt_arg::from_expcnt<expcnt>() |
|
||||
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
|
||||
}
|
||||
|
||||
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
index_t expcnt = waitcnt_arg::kMaxExpCnt,
|
||||
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
|
||||
CK_TILE_DEVICE void s_waitcnt_barrier()
|
||||
{
|
||||
s_waitcnt<vmcnt, expcnt, lgkmcnt>();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
template <index_t lgkmcnt = 0>
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
|
||||
}
|
||||
|
||||
template <index_t vmcnt = 0>
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
asm volatile("\
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
|
||||
@@ -158,4 +211,44 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return 163840;
|
||||
#else
|
||||
return 65536;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Helper function to convert address space enum to string
|
||||
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
|
||||
{
|
||||
switch(addr_space)
|
||||
{
|
||||
case address_space_enum::generic: return "generic";
|
||||
case address_space_enum::global: return "global";
|
||||
case address_space_enum::lds: return "lds";
|
||||
case address_space_enum::sgpr: return "sgpr";
|
||||
case address_space_enum::constant: return "constant";
|
||||
case address_space_enum::vgpr: return "vgpr";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// Architecture tags
|
||||
struct gfx11_t
|
||||
{
|
||||
};
|
||||
struct gfx12_t
|
||||
{
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_device_arch()
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
return gfx11_t{};
|
||||
#else // if defined(__gfx12__)
|
||||
return gfx12_t{};
|
||||
#endif
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,10 @@
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, typename ComputeType>
|
||||
@@ -32,6 +36,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
|
||||
{
|
||||
fp16x2_t rtn;
|
||||
rtn[0] = add<fp16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
|
||||
{
|
||||
fp8x4_t rtn;
|
||||
@@ -304,6 +316,44 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp16x2_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
|
||||
{
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
|
||||
#else
|
||||
union U32F162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp16x2_t* f162_a;
|
||||
};
|
||||
|
||||
union U32F162
|
||||
{
|
||||
uint32_t u32;
|
||||
fp16x2_t f162;
|
||||
};
|
||||
|
||||
U32F162_ADDR dword_addr;
|
||||
U32F162 cur_v;
|
||||
U32F162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.f162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.f162 = add_f16x2_t(cur_v.f162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
@@ -311,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
|
||||
@@ -406,6 +457,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp16_t>::value)
|
||||
{
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
|
||||
x.template get_as<fp16x2_t>()[i]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
|
||||
@@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
|
||||
const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
@@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
|
||||
{
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32x2_t x = __builtin_amdgcn_permlane32_swap(
|
||||
bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
|
||||
|
||||
thread_buffer<T, 2> v;
|
||||
v(0) = bit_cast<T>(x[0]);
|
||||
v(1) = bit_cast<T>(x[1]);
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
{
|
||||
|
||||
65
include/ck_tile/core/arch/workgroup_barrier.hpp
Normal file
65
include/ck_tile/core/arch/workgroup_barrier.hpp
Normal file
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct workgroup_barrier
|
||||
{
|
||||
CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
|
||||
|
||||
CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0)
|
||||
{
|
||||
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) != value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) < value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// enter critical zoon, assume buffer is zero when launch kernel
|
||||
CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(offset, 0, 1); }
|
||||
|
||||
// exit critical zoon, assume buffer is zero when launch kernel
|
||||
CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(offset, 1, 0); }
|
||||
|
||||
CK_TILE_DEVICE void inc(uint32_t offset = 0)
|
||||
{
|
||||
__syncthreads();
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(base_ptr + offset, 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t* base_ptr;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -15,7 +15,8 @@
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx11_generic__)
|
||||
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
|
||||
defined(__gfx1152__) || defined(__gfx11_generic__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
@@ -28,12 +29,6 @@
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define CK_TILE_HOST inline __host__
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
@@ -157,7 +152,7 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
@@ -196,6 +191,16 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// use llvm builtin bf16 data type after ROCm 6.5
|
||||
#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
|
||||
(HIP_VERSION_MAJOR >= 7)
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 1
|
||||
#else
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
@@ -228,6 +233,10 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
#define CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
@@ -241,22 +250,38 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#ifdef CK_TILE_USE_OCP_FP8
|
||||
#ifndef CK_TILE_USE_OCP_FP8
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
#define CK_TILE_USE_OCP_FP8 1
|
||||
#else
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_USE_OCP_FP8 1
|
||||
#else // for GPU code
|
||||
#else
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
#if __clang_major__ == 20
|
||||
#if __clang_major__ >= 20
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WA_ISSUE_2028
|
||||
#define CK_TILE_WA_ISSUE_2028 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WAVE32_ENABLED
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
#define CK_TILE_WAVE32_ENABLED
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Y pointed to R, we don't see a valuable use case.
|
||||
// Will enforce encoding to check Y not pointed to R if set to zero
|
||||
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
#define CK_TILE_ENC_SUPPORT_Y_TO_R 0
|
||||
#endif
|
||||
|
||||
@@ -19,6 +19,25 @@ namespace ck_tile {
|
||||
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
|
||||
// use make_array_with({...}) to construct an array with compatible behavior as old ck
|
||||
// TODO: manually added constructor same as old ck
|
||||
/**
|
||||
* @brief A fixed-size array container similar to std::array with additional utilities.
|
||||
*
|
||||
* This template class provides a lightweight fixed-size array with value semantics,
|
||||
* supporting both host and device functionality for GPU programming. It includes
|
||||
* specialized initialization methods and type punning capabilities.
|
||||
*
|
||||
* @tparam T_ The type of elements in the array
|
||||
* @tparam N_ The fixed number of elements in the array
|
||||
*
|
||||
* @note This implementation provides additional features beyond std::array:
|
||||
* - GPU compatibility via CK_TILE_HOST_DEVICE macros
|
||||
* - Type punning via get_as() and set_as() methods
|
||||
* - Various specialized access methods
|
||||
* - Specialized initialization behaviors
|
||||
*
|
||||
* The initializer_list constructor fills remaining elements with the last value
|
||||
* provided if the list size is smaller than N, which is different than std::array.
|
||||
*/
|
||||
template <typename T_, index_t N_>
|
||||
struct array
|
||||
{
|
||||
@@ -142,6 +161,14 @@ struct array
|
||||
|
||||
// empty Array
|
||||
|
||||
/// @brief Specialization of array container for zero elements.
|
||||
///
|
||||
/// This is a specialization of the array container template for the case where the number of
|
||||
/// elements is 0. It provides the same interface as the general array template, but with operations
|
||||
/// appropriate for an empty array.
|
||||
///
|
||||
/// @tparam T The type of elements stored in the array (not used in this specialization but
|
||||
/// maintained for API consistency).
|
||||
template <typename T>
|
||||
struct array<T, 0>
|
||||
{
|
||||
@@ -150,9 +177,27 @@ struct array<T, 0>
|
||||
CK_TILE_HOST_DEVICE constexpr array() {}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
|
||||
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
|
||||
{
|
||||
printf("array{size: %ld, data: [", static_cast<long>(N));
|
||||
for(index_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(a[i]);
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
|
||||
{
|
||||
printf("array{size: 0, data: []}");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
|
||||
{
|
||||
array<TData, NSize + 1> r;
|
||||
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
|
||||
static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
|
||||
r[number<NSize>{}] = x;
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -139,26 +139,21 @@ struct map
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("map{size_: %d, ", size_);
|
||||
//
|
||||
printf("impl_: [");
|
||||
//
|
||||
for(const auto& [k, d] : *this)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
//
|
||||
printf("]");
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename key, typename data, index_t max_size>
|
||||
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
|
||||
{
|
||||
printf("map{size_: %d, impl_: [", m.size_);
|
||||
for(const auto& [k, d] : m)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,13 +9,10 @@
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t, index_t, index_t>
|
||||
struct static_for;
|
||||
|
||||
template <index_t...>
|
||||
struct sequence;
|
||||
|
||||
@@ -196,15 +193,24 @@ struct sequence
|
||||
{
|
||||
return sequence<f(Is)...>{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print()
|
||||
{
|
||||
printf("sequence{size: %d, data: [", size());
|
||||
((printf("%d ", Is)), ...);
|
||||
printf("]}");
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
|
||||
{
|
||||
printf("sequence<");
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
bool first = true;
|
||||
(([&first](index_t value) {
|
||||
printf("%s%d", first ? "" : ", ", value);
|
||||
first = false;
|
||||
}(Is)),
|
||||
...);
|
||||
}
|
||||
printf(">");
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
template <typename T, T... Ints>
|
||||
struct __integer_sequence;
|
||||
@@ -1178,6 +1184,15 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
// the length count for slice size is from right to left(reverse slice)
|
||||
// or we can say, find the greatest common divider(gcd) from right to left, for the slice length
|
||||
//
|
||||
// e.g. <2, 8, 4>, slice length = 16
|
||||
// step-1: we take the right most <*, *, 4>, remaining 16/4=4
|
||||
// step-2: we only need 4 out of 8, of the midden dim, hence <*, 4, 4>
|
||||
// step-3: since nonthing remain, so the first dim we only need 1, hence<1, 4, 4>
|
||||
// => we got <1, 4, 4> as length for each slice
|
||||
// => total number of slice = <2, 8, 4> / <1, 4, 4> = <2, 2, 1>
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
@@ -1197,7 +1212,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
|
||||
// have split slices (right -> left)
|
||||
// or the first index that sliced length is different from the original length
|
||||
// or the first index (right -> left) that sliced length is different from the original length
|
||||
// clang-format on
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
@@ -1207,6 +1222,11 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
static_assert(SliceSize != 0, "slice size zero is invalid");
|
||||
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
|
||||
SliceSize ==
|
||||
0,
|
||||
"slice size can't evenly divide input sizes");
|
||||
using sliced_type =
|
||||
impl::reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
@@ -1222,9 +1242,8 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
constexpr auto
|
||||
slice_sequence(Seq, number<SliceSize>, Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
constexpr auto r =
|
||||
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
|
||||
|
||||
@@ -42,7 +42,11 @@ struct thread_buffer {
|
||||
|
||||
// TODO: this ctor can't ignore
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} {
|
||||
static_for<0, N, 1>{}(
|
||||
[&](auto i) { data[i] = o; }
|
||||
);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE auto & get() {return data; }
|
||||
|
||||
@@ -262,12 +262,18 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
return flag;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; }
|
||||
|
||||
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
|
||||
// clang-format off
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
@@ -294,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
#undef TP_COM_
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
template <typename... T>
|
||||
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
|
||||
{
|
||||
printf("tuple<");
|
||||
if constexpr(sizeof...(T) > 0)
|
||||
{
|
||||
bool first = true;
|
||||
static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) {
|
||||
if(!first)
|
||||
printf(", ");
|
||||
print(t.get(i));
|
||||
first = false;
|
||||
});
|
||||
}
|
||||
printf(">");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename... T>
|
||||
struct vector_traits<tuple<T...>>
|
||||
struct vector_traits<tuple<T...>, void>
|
||||
{
|
||||
using scalar_type = __type_pack_element<0, T...>;
|
||||
static constexpr index_t vector_size = sizeof...(T);
|
||||
@@ -396,11 +419,16 @@ struct tuple_array_impl<T, 1>
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename F, index_t... ids>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tuple_for(F&& f, sequence<ids...>)
|
||||
{
|
||||
return make_tuple(f(number<ids>{})...);
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
return generate_tuple_for(f, make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
@@ -465,6 +493,12 @@ transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
|
||||
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename Tuple, index_t... Is>
|
||||
constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence<Is...>)
|
||||
{
|
||||
return std::forward<F>(f)(std::forward<Tuple>(t).get(number<Is>{})...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
@@ -488,6 +522,13 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename Tuple>
|
||||
constexpr decltype(auto) apply(F&& f, Tuple&& t)
|
||||
{
|
||||
constexpr index_t N = std::decay_t<Tuple>::size();
|
||||
return detail::apply_impl(std::forward<F>(f), std::forward<Tuple>(t), make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
#else
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
using bfloat16_t = __bf16;
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
#endif
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = uint16_t;
|
||||
#endif
|
||||
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
|
||||
102
include/ck_tile/core/numeric/e8m0.hpp
Normal file
102
include/ck_tile/core/numeric/e8m0.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Unsigned representation of a conventional biased Float32 exponent.
|
||||
*
|
||||
* bias = 127;
|
||||
*
|
||||
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
|
||||
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
|
||||
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
|
||||
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
|
||||
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
|
||||
* E8M0_MIN = 0b00000000; => 2^-127
|
||||
* E8M0_MAX = 0b11111110; => 2^127
|
||||
* E8M0_NAN = 0b11111111; => NaN
|
||||
*/
|
||||
|
||||
struct e8m0_bexp_t
|
||||
{
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
|
||||
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
|
||||
constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; }
|
||||
|
||||
constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; }
|
||||
};
|
||||
|
||||
using e8m0_t = e8m0_bexp_t;
|
||||
using e8m0_raw_t = typename e8m0_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<e8m0_t>
|
||||
{
|
||||
using bitwise_type = e8m0_raw_t;
|
||||
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 0;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<e8m0_t>
|
||||
{
|
||||
static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127
|
||||
static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127
|
||||
static constexpr e8m0_raw_t binary_nan = 0b11111111;
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); }
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
|
||||
{
|
||||
using traits = numeric_traits<float>;
|
||||
if(data == numeric<e8m0_t>::binary_nan)
|
||||
{
|
||||
return std::numeric_limits<float>::signaling_NaN();
|
||||
}
|
||||
else if(data == 0)
|
||||
{
|
||||
return std::numeric_limits<float>::min();
|
||||
}
|
||||
else
|
||||
{
|
||||
return bit_cast<float>(static_cast<traits::bitwise_type>(data) << traits::mant);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -43,19 +43,19 @@ enum class fp8_interpretation
|
||||
};
|
||||
|
||||
/*
|
||||
* ______________FNUZ_________________ | ______________OCP________________
|
||||
* ______________FNUZ_________________ | ______________OCP________________
|
||||
* e4m3 e5m2 | e4m3 e5m2
|
||||
* bias : 8 16 | 7 15
|
||||
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
|
||||
* inf : N/A N/A | N/A s.11111.00
|
||||
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
|
||||
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
|
||||
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
|
||||
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
|
||||
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
|
||||
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
|
||||
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
|
||||
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
* 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
|
||||
*/
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
@@ -75,7 +75,7 @@ struct alignas(1) float8_e4m3_t
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
static constexpr int bias = 7; // OCP
|
||||
#else
|
||||
static constexpr int bias = 8; // FNUZ
|
||||
static constexpr int bias = 8; // FNUZ
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
@@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// fp8/bf8 type exponent/mantissa layout
|
||||
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
|
||||
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
|
||||
constexpr int DstT_bias = numeric_traits<DstT>::bias;
|
||||
constexpr bool is_fnuz =
|
||||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
|
||||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
|
||||
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int bias = numeric_traits<SrcT>::bias;
|
||||
constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
|
||||
constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
|
||||
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
|
||||
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
|
||||
|
||||
unsigned long long head, mantissa;
|
||||
int exponent, bias;
|
||||
unsigned int head, mantissa;
|
||||
int exponent;
|
||||
unsigned int sign;
|
||||
unsigned long long fInf, abs_mask;
|
||||
|
||||
head = src_bitwise & numeric_traits<SrcT>::head_mask;
|
||||
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
|
||||
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
|
||||
sign = head >> (SrcT_exp + SrcT_mant);
|
||||
bias = numeric_traits<SrcT>::bias;
|
||||
fInf = numeric_traits<SrcT>::Inf;
|
||||
abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
|
||||
unsigned int signed_inf = 0;
|
||||
unsigned int nan = 0;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
|
||||
signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
|
||||
nan = 0x80;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(DstT_exp == 4)
|
||||
{ // e4m3
|
||||
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
|
||||
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
|
||||
}
|
||||
else
|
||||
{ // e5m2
|
||||
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
|
||||
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
|
||||
}
|
||||
nan = (sign << 7) + 0x7f;
|
||||
nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
|
||||
}
|
||||
// Max values
|
||||
unsigned long long ifmax = 0;
|
||||
unsigned int ifmax = 0;
|
||||
if constexpr(is_float)
|
||||
{
|
||||
if constexpr(DstT_exp == 5)
|
||||
@@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// Deal with inf and NaNs
|
||||
if((src_bitwise & fInf) == fInf)
|
||||
{
|
||||
if constexpr(is_fnuz)
|
||||
return signed_inf;
|
||||
|
||||
return mantissa != 0 ? nan : signed_inf;
|
||||
}
|
||||
|
||||
@@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
return signed_inf;
|
||||
}
|
||||
|
||||
if(src_bitwise == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
@@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
@@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
// for this case, act_exponent could be larger. Just
|
||||
// that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
|
||||
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
|
||||
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
|
||||
// an undefined behavior of bit shifts >= type width).
|
||||
if(exponent_diff > DstT_mant)
|
||||
{
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
}
|
||||
bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
|
||||
(1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part and
|
||||
make something not midpoint look like midpoint. For example, the fp16 number
|
||||
@@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1ull << SrcT_mant);
|
||||
bool implicit_one = mantissa & (1u << SrcT_mant);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
|
||||
unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
|
||||
bool odd =
|
||||
mantissa & (1ull << (SrcT_mant -
|
||||
DstT_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa &
|
||||
(1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(f8_exponent == 0)
|
||||
{
|
||||
if((1ull << SrcT_mant) & mantissa)
|
||||
if((1u << SrcT_mant) & mantissa)
|
||||
{
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1ull << (SrcT_mant + 1)) & mantissa)
|
||||
if((1u << (SrcT_mant + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
@@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
}
|
||||
|
||||
if(f8_exponent == 0 && mantissa == 0)
|
||||
return is_fnuz ? 0 : (sign << 7);
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
mantissa &= (1 << DstT_mant) - 1;
|
||||
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
|
||||
return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, bool clip = true>
|
||||
@@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
{
|
||||
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
|
||||
"SrcT type must be fp8 or bf8.");
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
|
||||
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
|
||||
constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
|
||||
constexpr bool is_fnuz =
|
||||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
|
||||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
|
||||
@@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned long long sign = x >> 7;
|
||||
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
|
||||
int exponent = (x & 0x7F) >> SrcT_mant;
|
||||
unsigned int sign = x >> (SrcT_exp + SrcT_mant);
|
||||
unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
|
||||
int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
if((x & 0xff) == 0x80)
|
||||
@@ -530,7 +527,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == 0x80)
|
||||
if(x == SrcT(0x80))
|
||||
{
|
||||
return fNeg0;
|
||||
}
|
||||
@@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
|
||||
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
|
||||
{
|
||||
retval = x << 8;
|
||||
retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
|
||||
return bit_cast<DstT>(retval);
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using index_t = int32_t;
|
||||
using int32_t = int32_t;
|
||||
using long_index_t = int64_t;
|
||||
using int8_t = int8_t;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -19,14 +19,18 @@ struct constant
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <auto v>
|
||||
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
|
||||
{
|
||||
printf("%ld", static_cast<long>(v));
|
||||
}
|
||||
|
||||
template <typename T, T v>
|
||||
struct integral_constant : constant<v>
|
||||
{
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
static constexpr T value = v;
|
||||
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
|
||||
};
|
||||
|
||||
template <index_t v>
|
||||
@@ -79,4 +83,14 @@ CK_TILE_BINARY_OP(<=)
|
||||
#undef CK_TILE_LEFT_UNARY_OP
|
||||
#undef CK_TILE_BINARY_OP
|
||||
|
||||
template <typename T>
|
||||
struct is_constant : std::false_type
|
||||
{
|
||||
};
|
||||
template <auto v>
|
||||
struct is_constant<constant<v>> : std::true_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
inline constexpr bool is_constant_v = is_constant<T>::value;
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -31,8 +31,8 @@ struct scales
|
||||
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
|
||||
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
|
||||
-> decltype(std::declval<const Scale&>() * rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
operator()(const Right& rhs) const -> decltype(std::declval<const Scale&>() * rhs)
|
||||
{
|
||||
return lhs_ * rhs;
|
||||
}
|
||||
@@ -43,13 +43,13 @@ struct scales
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
template <typename Scale>
|
||||
__host__ __device__ scales(Scale)->scales<Scale>;
|
||||
__host__ __device__ scales(Scale) -> scales<Scale>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct plus
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs + rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs + rhs)
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
@@ -59,21 +59,21 @@ template <>
|
||||
struct plus<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs + rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs + rhs)
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ plus()->plus<void, void>;
|
||||
__host__ __device__ plus() -> plus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct minus
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs - rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs - rhs)
|
||||
{
|
||||
return lhs - rhs;
|
||||
}
|
||||
@@ -83,21 +83,21 @@ template <>
|
||||
struct minus<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs - rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs - rhs)
|
||||
{
|
||||
return lhs - rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ minus()->minus<void, void>;
|
||||
__host__ __device__ minus() -> minus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct multiplies
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs * rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs * rhs)
|
||||
{
|
||||
return lhs * rhs;
|
||||
}
|
||||
@@ -107,15 +107,15 @@ template <>
|
||||
struct multiplies<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs * rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs * rhs)
|
||||
{
|
||||
return lhs * rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ multiplies()->multiplies<void, void>;
|
||||
__host__ __device__ multiplies() -> multiplies<void, void>;
|
||||
|
||||
template <typename T>
|
||||
struct maximize
|
||||
@@ -327,8 +327,8 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct equal
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs == rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs == rhs)
|
||||
{
|
||||
return lhs == rhs;
|
||||
}
|
||||
@@ -338,15 +338,15 @@ template <>
|
||||
struct equal<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs == rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs == rhs)
|
||||
{
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ equal()->equal<void, void>;
|
||||
__host__ __device__ equal() -> equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct equal<float, float>
|
||||
@@ -369,8 +369,8 @@ struct equal<double, double>
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct less
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs < rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs < rhs)
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
@@ -380,21 +380,21 @@ template <>
|
||||
struct less<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs < rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs < rhs)
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ less()->less<void, void>;
|
||||
__host__ __device__ less() -> less<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct less_equal
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs <= rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs <= rhs)
|
||||
{
|
||||
return lhs <= rhs;
|
||||
}
|
||||
@@ -404,15 +404,15 @@ template <>
|
||||
struct less_equal<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs <= rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs <= rhs)
|
||||
{
|
||||
return lhs <= rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ less_equal()->less_equal<void, void>;
|
||||
__host__ __device__ less_equal() -> less_equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct less_equal<float, float>
|
||||
@@ -487,6 +487,9 @@ struct log2e<float>
|
||||
template <typename T = double>
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
template <typename T = double>
|
||||
constexpr T log2e_rcp_v = 1. / log2e<T>::value;
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
@@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp<double>(double x)
|
||||
return exp(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tanh_fast(T x)
|
||||
{
|
||||
return type_convert<T>((exp<T>(2.0 * type_convert<float>(x)) - 1.0) /
|
||||
(exp<T>(2.0 * type_convert<float>(x)) + 1.0));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tanh_fast<float>(float x)
|
||||
{
|
||||
// float a = __builtin_amdgcn_sinh(x);
|
||||
// float b = __builtin_amdgcn_cosh(x);
|
||||
// float e = a * __builtin_amdgcn_rcpf(b);
|
||||
// return e;
|
||||
|
||||
float a = 2.0f * log2e_v<float> * x;
|
||||
a = __builtin_amdgcn_exp2f(a);
|
||||
a = __builtin_amdgcn_rcpf(a + 1.0f);
|
||||
a = 2 * a;
|
||||
a = 1 - a;
|
||||
return a;
|
||||
|
||||
// float e, r, s, t, d;
|
||||
// float a = x;
|
||||
// s = abs(a);
|
||||
// t = -log2e_v<float> * 2.0f * s;
|
||||
// e = __builtin_amdgcn_exp2f(t);
|
||||
// d = e + 1.0f;
|
||||
// r = __builtin_amdgcn_rcpf(d);
|
||||
// r = e * (-r) + r;
|
||||
// if (s < 4.997253418e-3f) r = a;
|
||||
// union fipnr {float f; unsigned int i;};
|
||||
// fipnr r_; r_.f = r;
|
||||
// fipnr a_; a_.f = a;
|
||||
// { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
|
||||
// return r;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T log(T x)
|
||||
{
|
||||
|
||||
218
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
218
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
// modify from include/ck/utility/mxfp_utils.hpp
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils : numeric_traits<T>
|
||||
{
|
||||
|
||||
using traits = numeric_traits<T>;
|
||||
using _numeric = numeric<T>;
|
||||
using raw_type = typename traits::bitwise_type;
|
||||
|
||||
static constexpr int exp_mask = (1 << traits::exp) - 1;
|
||||
|
||||
static constexpr raw_type get_exponent(raw_type x)
|
||||
{
|
||||
// TODO: check if repeated calls are optimized.
|
||||
return (x >> traits::mant) & exp_mask;
|
||||
}
|
||||
static constexpr raw_type get_exponent(const T& x)
|
||||
{
|
||||
return get_exponent(bit_cast<raw_type>(x));
|
||||
}
|
||||
static constexpr bool is_positive(raw_type x)
|
||||
{
|
||||
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
|
||||
}
|
||||
static constexpr bool is_subnormal(raw_type x)
|
||||
{
|
||||
return get_exponent(x) == _numeric::binary_zero;
|
||||
}
|
||||
// TODO: replace double with template arg?
|
||||
static constexpr double get_mantissa(raw_type x)
|
||||
{
|
||||
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
|
||||
for(raw_type i = 0; i < traits::mant; ++i)
|
||||
{
|
||||
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
|
||||
x >>= 1;
|
||||
}
|
||||
return mantissa;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
|
||||
{
|
||||
using utils = numeric_utils<T>;
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
|
||||
return std::ldexp(sign * mant * scale, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
|
||||
{
|
||||
using bitwise_type = typename numeric_traits<T>::bitwise_type;
|
||||
|
||||
value /= scale;
|
||||
|
||||
if(std::abs(value) > float(numeric<T>::max()))
|
||||
{
|
||||
float max_value = numeric<T>::max();
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
bitwise_type sign =
|
||||
bit_cast<uint32_t>(value) >> (numeric_traits<float>::exp + numeric_traits<float>::mant);
|
||||
bitwise_type exp =
|
||||
((max_bitwise >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask) -
|
||||
(numeric_traits<float>::bias - numeric_traits<T>::bias);
|
||||
bitwise_type mantissa =
|
||||
max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
|
||||
uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
|
||||
mant_prev--;
|
||||
|
||||
mant_prev <<= (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
uint32_t prev_bit =
|
||||
((max_bitwise >> numeric_traits<float>::mant) << numeric_traits<float>::mant) |
|
||||
mant_prev;
|
||||
|
||||
float prev_val = bit_cast<float>(prev_bit);
|
||||
float diff = max_value - prev_val;
|
||||
|
||||
float actual_max = max_value + (diff / 2);
|
||||
|
||||
if(std::abs(value) < actual_max)
|
||||
{
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!numeric<T>::has_inf())
|
||||
{
|
||||
|
||||
return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
exp++;
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant);
|
||||
}
|
||||
}
|
||||
}
|
||||
const int mfmt = numeric_traits<float>::mant;
|
||||
uint32_t x;
|
||||
x = bit_cast<uint32_t>(value);
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int32_t exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
head = x & numeric_traits<float>::head_mask;
|
||||
mantissa = x & numeric_traits<float>::mant_mask;
|
||||
exponent = (head >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask;
|
||||
sign = head >> (numeric_traits<float>::mant + numeric_traits<float>::exp);
|
||||
bias = numeric_traits<float>::bias;
|
||||
|
||||
if(x == 0)
|
||||
{
|
||||
return 0b0;
|
||||
}
|
||||
|
||||
const int mini_bias = numeric_traits<T>::bias;
|
||||
const int mini_denormal_act_exponent = 1 - mini_bias;
|
||||
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
bool is_subnorm = false;
|
||||
|
||||
if(exponent == 0)
|
||||
{
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= mini_denormal_act_exponent)
|
||||
{
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
exponent_diff = 0;
|
||||
}
|
||||
mantissa += (1UL << mfmt);
|
||||
}
|
||||
|
||||
auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
|
||||
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
|
||||
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
|
||||
|
||||
float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
|
||||
|
||||
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
|
||||
{
|
||||
// closer to 0
|
||||
if(std::abs(value) <= std::abs(min_subnorm - value))
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
else
|
||||
return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
|
||||
}
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
|
||||
bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
|
||||
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
|
||||
|
||||
if(out_exponent == 0)
|
||||
{
|
||||
if((1UL << mfmt) & mantissa)
|
||||
{
|
||||
out_exponent = 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1UL << (mfmt + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
out_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - numeric_traits<T>::mant);
|
||||
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
{
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
}
|
||||
|
||||
mantissa &= (1UL << numeric_traits<T>::mant) - 1;
|
||||
return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(out_exponent << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -103,94 +103,92 @@ struct numeric_traits<float>
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) == static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) != static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return std::abs(static_cast<float>(x) - static_cast<float>(y)) < \
|
||||
static_cast<float>(numeric<type_>::epsilon()); \
|
||||
} \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) { return not operator==(x, y); } \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
}
|
||||
|
||||
357
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
357
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
@@ -0,0 +1,357 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#define CK_TILE_FP4_CVT_DEVICE 1
|
||||
#else
|
||||
#define CK_TILE_FP4_CVT_DEVICE 0
|
||||
#endif
|
||||
|
||||
#define TEST_convert_with_table 0
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
|
||||
|
||||
// TODO: Add stochastic method
|
||||
struct pk_float4_e2m1_t
|
||||
{
|
||||
// TODO: Can we merge raw_type and type?
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {}
|
||||
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
|
||||
: data{float_to_e2m1(init, scale)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type unpack(number<I>) const;
|
||||
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1)
|
||||
{
|
||||
return (x1 << 4) | (x0 & 0b00001111);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table
|
||||
static constexpr float e2m1_to_fp32_table[16] = {
|
||||
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
|
||||
static constexpr fp16_t e2m1_to_fp16_table[16] = {
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
|
||||
};
|
||||
#endif
|
||||
};
|
||||
|
||||
using pk_fp4_t = pk_float4_e2m1_t;
|
||||
using pk_fp4_raw_t = typename pk_fp4_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<pk_fp4_t>
|
||||
{
|
||||
using bitwise_type = pk_fp4_raw_t;
|
||||
|
||||
static constexpr int exp = 2;
|
||||
static constexpr int mant = 1;
|
||||
static constexpr int bias = 1;
|
||||
static constexpr int PackedSize = 2;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<pk_fp4_t>
|
||||
{
|
||||
static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
|
||||
static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
|
||||
static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
|
||||
static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
|
||||
{
|
||||
static_assert(I < 2, "Index is out of range.");
|
||||
if constexpr(I == 1)
|
||||
return (data >> 4);
|
||||
else
|
||||
return data & 0b00001111;
|
||||
}
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, pk_fp4_t)
|
||||
// TODO: consider replace this macro to improve performance
|
||||
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
namespace impl {
|
||||
|
||||
// Decodes a packed pair of e2m1 codes via the AMD scale-convert
// builtins. `src` holds two 4-bit codes in one byte; `scale` is applied
// by the instruction during decode. The trailing 0 operand presumably
// selects which packed pair of the source to convert (we always supply
// the data in the low byte) — confirm against the ROCm builtin docs.
// For scalar result types only lane 0 of the converted pair is returned.
template <typename T>
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
{
    if constexpr(std::is_same_v<T, fp32_t>)
        return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
    else if constexpr(std::is_same_v<T, fp32x2_t>)
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
    else if constexpr(std::is_same_v<T, fp16_t>)
        return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
    else if constexpr(std::is_same_v<T, fp16x2_t>)
        return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
    else if constexpr(std::is_same_v<T, bf16_t>)
        return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
    else if constexpr(std::is_same_v<T, bf16x2_t>)
        return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
    else
        // Dependent-false so the assert only fires if this branch is
        // actually instantiated with an unsupported T.
        static_assert(std::false_type::value, "Unsupported type.");
    return T{};
}
|
||||
// Encodes `src` into packed e2m1 codes via the AMD scale-convert
// builtins. Scalar inputs are duplicated into both lanes; `scale` is
// applied by the instruction before rounding. The builtins write the
// converted pair into a 32-bit destination; only the low byte (the
// first packed pair, selected by the trailing 0 operand) is returned.
template <typename T>
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
{
    // View the 32-bit builtin result as four packed fp4 bytes so the
    // low pair can be extracted without shifts.
    union
    {
        uint32_t u32;
        pk_fp4_raw_t pf4[4];
    } cvt{0};
    if constexpr(std::is_same_v<T, fp32_t>)
        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
    else if constexpr(std::is_same_v<T, fp32x2_t>)
        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
    else if constexpr(std::is_same_v<T, fp16_t>)
        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
    else if constexpr(std::is_same_v<T, fp16x2_t>)
        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
    else if constexpr(std::is_same_v<T, bf16_t>)
        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
    else if constexpr(std::is_same_v<T, bf16x2_t>)
        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
    else
        // Dependent-false: fires only when instantiated with an
        // unsupported T.
        static_assert(std::false_type::value, "Unsupported type.");
    return cvt.pf4[0];
}
|
||||
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
// Decodes the low e2m1 lane to bf16, applying `scale`.
// Device build uses the hardware convert; host falls back to the
// generic float decode followed by a bf16 narrowing.
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_from_f4<bf16_t>(data, scale);
#else
    return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
#endif
}
|
||||
|
||||
// Decodes both packed e2m1 lanes to a bf16 pair, applying `scale`.
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_from_f4<bf16x2_t>(data, scale);
#else
    return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
                    type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
#endif
}
|
||||
|
||||
// TODO: make float_to_e2m1 generic so that we can convert from directly.
// Encodes one float as an e2m1 code, applying `scale` before rounding.
// On device the builtin produces both nibbles (scalar input duplicated);
// the host path's convert_to_type presumably yields the code in the low
// nibble — confirm against mxfp_convert.
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_to_f4(x, scale);
#else
    return convert_to_type<pk_fp4_t>(x, scale);
#endif
}
|
||||
// Encodes one float into a pk_fp4_t (see float_to_e2m1).
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
{
    return float_to_e2m1(x, scale);
}
|
||||
// Encodes one fp16 value into a pk_fp4_t; host path goes through float.
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_to_f4(x, scale);
#else
    return float_to_e2m1(type_convert<float>(x), scale);
#endif
}
|
||||
// Encodes one bf16 value into a pk_fp4_t; host path goes through float.
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_to_f4(x, scale);
#else
    return float_to_e2m1(type_convert<float>(x), scale);
#endif
}
|
||||
// Encodes both fp16 lanes into one packed fp4 byte.
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_to_f4(x, scale);
#else
    // Host fallback: encode each lane separately, then pack lane 0 into
    // the low nibble and lane 1 into the high nibble.
    return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
                          float_to_e2m1(type_convert<float>(x[1]), scale));
#endif
}
|
||||
// Encodes both bf16 lanes into one packed fp4 byte.
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_to_f4(x, scale);
#else
    // Host fallback: encode each lane separately, then pack.
    return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
                          float_to_e2m1(type_convert<float>(x[1]), scale));
#endif
}
|
||||
// Encodes both float lanes into one packed fp4 byte.
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_to_f4(x, scale);
#else
    // Host fallback: encode each lane separately, then pack.
    return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
#endif
}
|
||||
|
||||
// Thin free-function wrappers over the pk_fp4_t member decoders so the
// name-based scaled_type_convert dispatch machinery can reach them.
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
{
    return x.to_fp32x2(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
{
    return x.to_fp16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
{
    return x.to_bf16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
{
    return x.to_float(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
{
    return x.to_fp16(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
{
    return x.to_bf16(scale);
}
|
||||
|
||||
#if TEST_convert_with_table == 0
// Decode paths used when the lookup-table implementation is disabled:
// device builds use the hardware convert builtins, host builds the
// generic software decoder.

// Decodes the low e2m1 lane to float, applying `scale`.
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_from_f4<fp32_t>(data, scale);
#else
    return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
#endif
}
// Decodes both packed lanes to a float pair, applying `scale`.
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_from_f4<fp32x2_t>(data, scale);
#else
    return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
                    convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
#endif
}

// Decodes the low e2m1 lane to fp16, applying `scale`.
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_from_f4<fp16_t>(data, scale);
#else
    return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
#endif
}
// Decodes both packed lanes to an fp16 pair, applying `scale`.
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
    return impl::_from_f4<fp16x2_t>(data, scale);
#else
    return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
                    type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
#endif
}
|
||||
#else
|
||||
// Table-based decode of the low e2m1 lane to float, applying `scale`.
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
    return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
}
|
||||
// Table-based decode of both packed e2m1 lanes to floats, applying `scale`.
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
    // Fixed: the second element was written as `unpack(number<1>{}]` —
    // a missing closing parenthesis that made this branch fail to compile.
    return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale,
                    e2m1_to_fp32_table[unpack(number<1>{})] * scale};
}
|
||||
// Table-based decode of the low e2m1 lane to fp16, applying `scale`.
// The product is computed in float and narrowed back to fp16 on return.
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
    return type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale;
}
|
||||
// Table-based decode of both packed e2m1 lanes to an fp16 pair,
// applying `scale` in float precision before narrowing each lane.
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
    return fp16x2_t{
        type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
        type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -99,7 +99,7 @@ struct numeric_traits<pk_int4_t>
|
||||
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
{
|
||||
@@ -116,6 +116,24 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
return res;
|
||||
}
|
||||
|
||||
// Decodes a packed pair of signed 4-bit integers to two floats,
// treating each nibble as two's-complement in [-8, 7].
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    // Extract the unsigned nibble values.
    float x_l = ((x_u8 & 0x0f) >> 0);
    float x_h = ((x_u8 & 0xf0) >> 4);

    // Reinterpret each nibble as signed two's-complement.
    // Fixed: the high lane was computed from x_l (copy-paste bug),
    // so x_h never carried the high nibble's signed value.
    x_l = x_l > 7 ? x_l - 16 : x_l;
    x_h = x_h > 7 ? x_h - 16 : x_h;

#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#else
    // Fixed: was `#elif` with no condition, which is ill-formed and
    // left `res` undefined on the non-shuffle path.
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -63,8 +64,44 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
|
||||
|
||||
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
|
||||
float scale) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, scale); \
|
||||
} \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, 1.f); \
|
||||
}
|
||||
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4)
|
||||
#undef CK_TILE_SCALED_TYPE_CONVERT
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
|
||||
|
||||
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
|
||||
// ... unless we have other vector_traits specialization
|
||||
template <typename T, typename>
|
||||
template <typename T, typename = void>
|
||||
struct vector_traits
|
||||
{
|
||||
using scalar_type =
|
||||
@@ -94,7 +94,7 @@ struct vector_traits
|
||||
|
||||
// specialization for ext_vector_type()
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
|
||||
{
|
||||
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
|
||||
static constexpr index_t vector_size = N;
|
||||
@@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf16
|
||||
// using bf16_t = ...
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#if __clang_major__ == 20
|
||||
#if __clang_major__ >= 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -133,6 +134,28 @@ struct buffer_view<address_space_enum::generic,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the generic address space, we do not support the transpose instruction in the buffer view.
|
||||
Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -187,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: generic, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: Global
|
||||
@@ -359,6 +360,28 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the global memory address space, we do not support the transpose instruction in the buffer
|
||||
view. Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -407,10 +430,12 @@ struct buffer_view<address_space_enum::global,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
|
||||
|
||||
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem,
|
||||
cached_buf_res_,
|
||||
src_wave_buffer_resource,
|
||||
i,
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
@@ -710,28 +735,6 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Global, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: LDS
|
||||
@@ -852,6 +855,47 @@ struct buffer_view<address_space_enum::lds,
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
|
||||
[[maybe_unused]] index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr address_space_enum addr_space = get_address_space();
|
||||
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
|
||||
p_data_ + i + linear_offset);
|
||||
#else
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
return X{invalid_element_value_};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -906,45 +950,39 @@ struct buffer_view<address_space_enum::lds,
|
||||
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
|
||||
// ds_write_b128
|
||||
// TODO: remove this after compiler fix
|
||||
// clang-format off
|
||||
static_assert(
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// int8 on thread buffer
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
// ext_vector_type for pk_int4 must use int8_t as type
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
// clang-format on
|
||||
|
||||
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
|
||||
{
|
||||
@@ -955,6 +993,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
|
||||
{
|
||||
@@ -965,6 +1005,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
|
||||
{
|
||||
@@ -975,6 +1017,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
|
||||
{
|
||||
@@ -985,6 +1029,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
|
||||
{
|
||||
@@ -1048,28 +1094,6 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Lds, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: Vgpr
|
||||
@@ -1223,28 +1247,6 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Vgpr, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
@@ -1270,4 +1272,25 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
p, buffer_size, invalid_element_value};
|
||||
}
|
||||
|
||||
// Generalized print function for all buffer_view variants
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
amd_buffer_coherence_enum Coherence>
|
||||
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
|
||||
T,
|
||||
BufferSizeType,
|
||||
InvalidElementUseNumericalZeroValue,
|
||||
Coherence>& bv)
|
||||
{
|
||||
printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
|
||||
address_space_to_string(BufferAddressSpace),
|
||||
static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
|
||||
print(bv.buffer_size_);
|
||||
printf(", invalid_element_value_: ");
|
||||
print(bv.invalid_element_value_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -18,32 +18,8 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
template <typename TileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -51,35 +27,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -138,42 +90,25 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto
|
||||
async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
return tile_window.async_load(
|
||||
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
|
||||
446
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
446
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
@@ -0,0 +1,446 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
constexpr int DS_READ_TR_SIZE()
|
||||
{
|
||||
return 8; // Literal constant, evaluated at compile time
|
||||
}
|
||||
|
||||
namespace util {
|
||||
template <typename Suffix, typename Sequence>
|
||||
struct is_sequence_suffix
|
||||
{
|
||||
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
|
||||
|
||||
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
|
||||
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
|
||||
|
||||
static constexpr bool value =
|
||||
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
|
||||
};
|
||||
|
||||
template <index_t... Xs>
|
||||
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Suffix, typename Sequence>
|
||||
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
|
||||
|
||||
} // namespace util
|
||||
|
||||
// Default policy: Retains original 2D transpose behavior
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad16
|
||||
{
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad8
|
||||
{
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
// Select based on data size
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16<LaneGroupSize>::InputEncoding,
|
||||
typename Quad8<LaneGroupSize>::InputEncoding>;
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16<LaneGroupSize>::OutputEncoding,
|
||||
typename Quad8<LaneGroupSize>::OutputEncoding>;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
|
||||
// Programmable: Element grouping function
|
||||
static constexpr auto group_func = [](auto idx) {
|
||||
return idx; // Identity mapping
|
||||
};
|
||||
|
||||
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
|
||||
struct ValidationTraitsImpl
|
||||
{
|
||||
using QuadEncoding = std::conditional_t<ReverseDirection,
|
||||
QuadOutputEncoding<LaneGroupSize>,
|
||||
QuadInputEncoding<LaneGroupSize>>;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs = QuadEncoding::hs_lengthss_;
|
||||
// 1. Must be 2D tensor
|
||||
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
|
||||
// 2. Quad pattern must be suffix of input pattern
|
||||
static constexpr bool suffix_valid_dim0 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
|
||||
static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0];
|
||||
|
||||
static constexpr auto input_ps_major_last =
|
||||
input_ps_major[number<input_ps_major.size() - 1>{}];
|
||||
static constexpr auto input_ps_minor_last =
|
||||
input_ps_minor[number<input_ps_minor.size() - 1>{}];
|
||||
|
||||
using psys_offset = ck_tile::sequence<input_hs[I0].size() - quad_hs[I0].size(),
|
||||
input_hs[I1].size() - quad_hs[I1].size()>;
|
||||
static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
return number<quad_ps_minor0[i] + psys_offset{}[quad_ps_major0[i] - 1]>{};
|
||||
},
|
||||
number<quad_ps_minor0.size()>{});
|
||||
|
||||
static constexpr bool ps_mapping_valid =
|
||||
util::is_sequence_suffix_v<decltype(quad_ps_major0), decltype(input_ps_major_last)> &&
|
||||
util::is_sequence_suffix_v<decltype(shifted_quad_ps_minor0),
|
||||
decltype(input_ps_minor_last)>;
|
||||
|
||||
// 4. YS→RHS mapping constraints
|
||||
static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1,
|
||||
"YS->RHS mapping must be single dimension");
|
||||
static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1,
|
||||
"YS->RHS mapping must be the last dimension");
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
|
||||
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
|
||||
template <typename InDstrEncode, bool ReverseDirection = false>
|
||||
struct ValidationTraits
|
||||
{
|
||||
static constexpr bool value =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ? 64
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ? 32
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value ? 16
|
||||
: 0;
|
||||
};
|
||||
};
|
||||
template <typename TileDistribution_, typename DataType_, typename Policy>
|
||||
struct TransposeTileDistrChecker
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
|
||||
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
|
||||
|
||||
static constexpr bool distr_encoding_valid = Validator::value;
|
||||
};
|
||||
|
||||
// this is used to generate the transposed output tile distribution encoding
|
||||
// based on the input tile distribution encoding
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>,
|
||||
bool ReverseDirection = false>
|
||||
struct TransposeTileDistributionTraits
|
||||
{
|
||||
using InDstrEncode = remove_cvref_t<TileDistributionEncoding_>;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::LaneGroupSize;
|
||||
static_assert(Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::value,
|
||||
"The input tile distribution encoding is not valid for transpose!");
|
||||
|
||||
using QuadInputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>>;
|
||||
using QuadOutputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>>;
|
||||
|
||||
static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_;
|
||||
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr index_t dim0 = Policy::transpose_dims[0];
|
||||
static constexpr index_t dim1 = Policy::transpose_dims[1];
|
||||
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
|
||||
// for transpose load
|
||||
// remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse
|
||||
// dims and append the quad_output_hs_lengthss to the end of each dimension
|
||||
static constexpr auto outer_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
constexpr auto input_i = input_hs_lengthss[i];
|
||||
constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size();
|
||||
return typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss);
|
||||
static constexpr auto dst_out_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
auto outer_i = reversed_outer_hs_lengthss[i];
|
||||
// append the reversed quad output hs lengths to the outer hs lengths
|
||||
return outer_i.push_back(quad_output_hs_lengthss[i]);
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element (which is for
|
||||
// thread distr) of the major sequence
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_major.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_major[i].size();
|
||||
constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size();
|
||||
constexpr auto quad_out = quad_output_ps_to_rhss_major0;
|
||||
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences (i.e. warp), keep them unchanged
|
||||
return input_ps_to_rhss_major[i].transform(swap_one_and_two);
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto quad_idx_offset =
|
||||
transform_tuples([](auto x) { return number<x.size()>{}; }, reversed_outer_hs_lengthss);
|
||||
|
||||
// minus 1 because RsLength is not counted
|
||||
static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ps_to_rhss_major0));
|
||||
static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ys_to_rhs_major));
|
||||
|
||||
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
|
||||
[](auto i) {
|
||||
constexpr auto input_i = input_ps_to_rhss_minor[i];
|
||||
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
|
||||
{
|
||||
constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size();
|
||||
constexpr auto outer_ps =
|
||||
typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
|
||||
return outer_ps.push_back(quad_output_ps_minor_offset +
|
||||
quad_output_ps_to_rhss_minor0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_i;
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_minor.size()>{});
|
||||
|
||||
static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back();
|
||||
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
static constexpr auto dst_ys_to_rhs_major =
|
||||
outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back(
|
||||
number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{});
|
||||
|
||||
using TransposedDstrEncode =
|
||||
tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
};
|
||||
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using OutputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, false>;
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using InputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;
|
||||
|
||||
template <typename InnerEncode,
|
||||
index_t kLeadIterPerWarp,
|
||||
index_t kSecondIterPerWarp,
|
||||
index_t kLeadNumWarps,
|
||||
index_t kSecondNumWarps>
|
||||
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
{
|
||||
constexpr auto block_outer_dst_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
|
||||
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto blk_distr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
|
||||
|
||||
return blk_distr_encode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
|
||||
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* This function is intended for use with statically distributed tensor tiles, where the input
|
||||
* and output tile distributions differ due to the transpose operation. It ensures that the
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
* @tparam NumCoord The number of coordinates (dimensions).
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
{
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose<Policy>();
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
|
||||
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
|
||||
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
|
||||
|
||||
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
|
||||
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
|
||||
static_assert(y_in_element_space_size == y_out_element_space_size,
|
||||
"the element space size is not the same!");
|
||||
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
|
||||
"the vector length is not the same!");
|
||||
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
|
||||
constexpr index_t num_of_access =
|
||||
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
|
||||
|
||||
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
|
||||
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
|
||||
out_tensor.get_thread_buffer().template set_as<DataVec>(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -53,10 +53,13 @@ struct is_null_tile_window<null_tile_window<T>> : public std::true_type
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_null_tile_window_v = impl::is_null_tile_window<remove_cvref_t<T>>::value;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
|
||||
{
|
||||
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
|
||||
return is_null_tile_window_v<remove_cvref_t<T>>;
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
|
||||
: static_cast<index_t>(idx_y_start[ii]);
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
|
||||
@@ -303,6 +303,6 @@ struct tile_sweeper
|
||||
template <typename T,
|
||||
typename F,
|
||||
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -81,7 +81,7 @@ struct tensor_adaptor
|
||||
|
||||
template <index_t IDimHidden>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_transform_and_its_upper_dimension(number<IDimHidden>)
|
||||
get_transform_and_its_upper_dimension(number<IDimHidden>)
|
||||
{
|
||||
// FIXME: length of bottom dimension is not known, since info about lower dim length are not
|
||||
// saved in transformation
|
||||
@@ -119,13 +119,13 @@ struct tensor_adaptor
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
constexpr auto all_low_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
constexpr auto all_up_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
@@ -259,6 +259,7 @@ struct tensor_adaptor
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
|
||||
|
||||
template <index_t Internal = 0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
|
||||
const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
|
||||
const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
|
||||
@@ -266,7 +267,9 @@ struct tensor_adaptor
|
||||
auto vector_lengths = guaranteed_vector_lengths;
|
||||
auto vector_strides = guaranteed_vector_strides;
|
||||
|
||||
static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
static_for<0,
|
||||
Internal ? std::min(Internal, get_num_of_transform()) : get_num_of_transform(),
|
||||
1>{}([&](auto itran) {
|
||||
constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
@@ -298,42 +301,16 @@ struct tensor_adaptor
|
||||
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
|
||||
set_container_subset(vector_strides, up_dims, up_vector_strides);
|
||||
});
|
||||
|
||||
constexpr auto top_dims = TopDimensionHiddenIds{};
|
||||
|
||||
return make_tuple(get_container_subset(vector_lengths, top_dims),
|
||||
get_container_subset(vector_strides, top_dims));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_adaptor{");
|
||||
|
||||
//
|
||||
printf("transforms: ");
|
||||
print(transforms_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("LowerDimensionHiddenIds: ");
|
||||
print(LowerDimensionHiddenIdss{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("UpperDimensionHiddenIds: ");
|
||||
print(UpperDimensionHiddenIdss{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("BottomDimensionHiddenIds: ");
|
||||
print(BottomDimensionHiddenIds{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("TopDimensionHiddenIds: ");
|
||||
print(TopDimensionHiddenIds{});
|
||||
|
||||
printf("}");
|
||||
if constexpr(Internal > 0)
|
||||
{
|
||||
return make_tuple(vector_lengths, vector_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto top_dims = TopDimensionHiddenIds{};
|
||||
return make_tuple(get_container_subset(vector_lengths, top_dims),
|
||||
get_container_subset(vector_strides, top_dims));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -341,6 +318,40 @@ struct tensor_adaptor
|
||||
ElementSize element_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename BottomDimensionHiddenIds,
|
||||
typename TopDimensionHiddenIds>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
BottomDimensionHiddenIds,
|
||||
TopDimensionHiddenIds>& adaptor)
|
||||
{
|
||||
printf("tensor_adaptor{\n");
|
||||
printf(" transforms: [");
|
||||
print(adaptor.get_transforms());
|
||||
printf("],\n");
|
||||
|
||||
printf(" LowerDimensionHiddenIds: [");
|
||||
print(LowerDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" UpperDimensionHiddenIds: [");
|
||||
print(UpperDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" BottomDimensionHiddenIds: [");
|
||||
print(BottomDimensionHiddenIds{});
|
||||
printf("],\n");
|
||||
|
||||
//
|
||||
printf(" TopDimensionHiddenIds: [");
|
||||
print(TopDimensionHiddenIds{});
|
||||
printf("]\n}\n");
|
||||
}
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
|
||||
@@ -461,7 +472,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
|
||||
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
|
||||
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
@@ -470,8 +481,8 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
number<num_new_transform>{});
|
||||
|
||||
// new top dimension's hidden ids
|
||||
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
constexpr auto unordered_new_top_dim_hidden_ids =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_top_dim_unordered2ordered = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
|
||||
@@ -595,8 +606,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, sequence out
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id so every dim id is unique
|
||||
@@ -619,8 +629,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
});
|
||||
|
||||
return low_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
|
||||
@@ -643,8 +652,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, constexpr tuple out
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id
|
||||
@@ -653,8 +661,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
});
|
||||
|
||||
return up_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
// constexpr tuple to sequence
|
||||
return generate_sequence_v2(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -133,32 +133,45 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
|
||||
|
||||
template <index_t Internal = 0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
|
||||
{
|
||||
return Base::get_top_dimension_safe_vector_length_strides(
|
||||
return Base::template get_top_dimension_safe_vector_length_strides<Internal>(
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_descriptor{");
|
||||
|
||||
// tensor_adaptor
|
||||
Base::print();
|
||||
printf(", ");
|
||||
|
||||
// element_space_size_
|
||||
printf("element_space_size_: ");
|
||||
print(element_space_size_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
ElementSpaceSize element_space_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename TopDimensionHiddenIds,
|
||||
typename ElementSpaceSize,
|
||||
typename GuaranteedVectorLengths,
|
||||
typename GuaranteedVectorStrides>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
TopDimensionHiddenIds,
|
||||
ElementSpaceSize,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>& descriptor)
|
||||
{
|
||||
printf("tensor_descriptor{\n");
|
||||
// first print the tensor adaptor part of the descriptor using the base class print
|
||||
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
|
||||
printf("element_space_size_: %ld,\n",
|
||||
static_cast<long>(descriptor.get_element_space_size().value));
|
||||
printf("guaranteed_vector_lengths: ");
|
||||
print(GuaranteedVectorLengths{});
|
||||
printf(",\nguaranteed_vector_strides: ");
|
||||
print(GuaranteedVectorStrides{});
|
||||
printf("}\n}\n");
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename ElementSpaceSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
|
||||
@@ -365,12 +378,29 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
|
||||
|
||||
constexpr index_t first_dim_length = []() {
|
||||
if constexpr(is_constant_v<remove_cvref_t<decltype(element_space_size)>>)
|
||||
return decltype(element_space_size)::value;
|
||||
else
|
||||
return -1;
|
||||
}();
|
||||
using last_t = remove_cvref_t<decltype(lengths.template get<N - 1>())>;
|
||||
constexpr index_t last_dim_length = []() {
|
||||
if constexpr(is_constant_v<last_t>)
|
||||
return std::max(last_t::value, GuaranteedLastDimensionVectorLength);
|
||||
else
|
||||
return -1;
|
||||
}();
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
typename sequence_merge<sequence<first_dim_length>,
|
||||
typename uniform_sequence_gen<N - 1, -1>::type,
|
||||
sequence<last_dim_length>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;
|
||||
typename sequence_merge<sequence<1>,
|
||||
typename uniform_sequence_gen<N - 1, -1>::type,
|
||||
sequence<1>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
|
||||
@@ -161,7 +161,8 @@ struct tensor_view
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset) const
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
@@ -181,7 +182,8 @@ struct tensor_view
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
@@ -210,6 +212,27 @@ struct tensor_view
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t coord_extra_offset,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
(coord.get_offset() + coord_extra_offset) / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
@@ -230,6 +253,33 @@ struct tensor_view
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element // flag
|
||||
) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
|
||||
}
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
@@ -384,22 +434,6 @@ struct tensor_view
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
|
||||
// buf_
|
||||
printf("buf_: ");
|
||||
print(buf_);
|
||||
printf(", ");
|
||||
|
||||
// desc_
|
||||
printf("desc_: ");
|
||||
print(desc_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// member
|
||||
buffer_view buf_;
|
||||
TensorDesc desc_;
|
||||
@@ -411,18 +445,21 @@ struct null_tensor_view
|
||||
};
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
const tensor_descriptor<Ts...>& desc)
|
||||
{
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
@@ -441,12 +478,14 @@ make_naive_tensor_view(DataType* p,
|
||||
number<GuaranteedLastDimensionVectorLength>{},
|
||||
number<GuaranteedLastDimensionVectorStride>{});
|
||||
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
@@ -458,7 +497,8 @@ make_naive_tensor_view_packed(DataType* p,
|
||||
auto desc =
|
||||
make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});
|
||||
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
@@ -488,6 +528,7 @@ template <typename TensorView,
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
|
||||
{
|
||||
|
||||
constexpr index_t num_dim = DoPads::size();
|
||||
|
||||
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
|
||||
|
||||
@@ -204,7 +204,7 @@ struct tile_distribution
|
||||
// FIXME: it's hacky to get Y index from Distributed-Index
|
||||
template <typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_y_indices_from_distributed_indices(DistributedIndices)
|
||||
get_y_indices_from_distributed_indices(DistributedIndices)
|
||||
{
|
||||
constexpr auto ys_idx_arr = [] {
|
||||
array<index_t, NDimY> ys_idx;
|
||||
@@ -230,24 +230,6 @@ struct tile_distribution
|
||||
{
|
||||
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
//
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(DstrEncode{});
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_d_: ");
|
||||
print(ys_to_d_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
@@ -268,7 +250,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t
|
||||
// this returns a constexpr encoding of tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
|
||||
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
|
||||
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
|
||||
@@ -544,26 +526,26 @@ namespace detail {
|
||||
//
|
||||
// e.g
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
|
||||
// totally 16 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
|
||||
// |--> slice along this P dim, will split threads, not supported
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, but this Y sim need to split into 2
|
||||
@@ -579,27 +561,55 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
|
||||
|
||||
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
|
||||
static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
|
||||
constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
|
||||
|
||||
constexpr auto x_slice_ends_ = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(x_slice_ends[i] == -1)
|
||||
{
|
||||
// -1 means till the end
|
||||
constexpr auto x_length_ =
|
||||
container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
|
||||
return x_length_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return x_slice_ends[i];
|
||||
}
|
||||
},
|
||||
number<x_slice_ends.size()>{});
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
|
||||
|
||||
constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
constexpr auto len_ = x_slice_lengths[i];
|
||||
static_assert(len_ % p_len_over_h[i] == 0,
|
||||
"slice length must be dividable by p_len_over_h");
|
||||
return number<len_ / p_len_over_h[i]>{};
|
||||
},
|
||||
number<x_slice_lengths.size()>{});
|
||||
|
||||
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
|
||||
constexpr auto src_y_dims = src_y_info[number<0>{}];
|
||||
constexpr auto src_y_maps = src_y_info[number<1>{}];
|
||||
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
|
||||
|
||||
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
|
||||
{
|
||||
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
|
||||
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
|
||||
auto y_slice_lengths = Encoding::detail::ys_lengths_;
|
||||
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
|
||||
|
||||
// This lambda will modify some value outside, so c++ will not treat return value as
|
||||
// constexpr
|
||||
// TODO: ugly
|
||||
auto new_h_lengths = transform_tuples(
|
||||
[&](auto h_len, auto id) {
|
||||
constexpr auto sliced_h =
|
||||
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
|
||||
constexpr auto sliced_h = reverse_slice_sequence(
|
||||
h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
|
||||
|
||||
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
|
||||
constexpr auto sliced_h_index = sliced_h[number<2>{}];
|
||||
@@ -607,26 +617,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
// update y_slice_lengths
|
||||
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
|
||||
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
|
||||
constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
|
||||
|
||||
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
|
||||
"not sliced at y dim, please check");
|
||||
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[found_y_index - i]) =
|
||||
sliced_h_lens[sliced_h_index - i];
|
||||
});
|
||||
{
|
||||
constexpr auto sliced_y_to_h_lens =
|
||||
pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
|
||||
constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
|
||||
static_for<0, sliced_y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
|
||||
sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
|
||||
});
|
||||
}
|
||||
// TODO: add validations not across p dim
|
||||
|
||||
// NOTE: this y_origin is for all dims, not only current dim
|
||||
// will later use pick to select target dim
|
||||
constexpr auto y_origin = [&]() {
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
|
||||
// can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
|
||||
constexpr auto y_to_h_len =
|
||||
pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
|
||||
constexpr auto y_to_h_dims = y_to_h_len.size();
|
||||
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
|
||||
|
||||
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
|
||||
|
||||
static_for<0, y_to_h_dims, 1>{}([&](auto i) {
|
||||
y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
|
||||
});
|
||||
return y_origin_;
|
||||
}();
|
||||
@@ -645,8 +668,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
|
||||
|
||||
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
|
||||
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
|
||||
@@ -672,4 +694,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution
|
||||
template <typename PsYs2XsAdaptor_,
|
||||
typename Ys2DDescriptor_,
|
||||
typename StaticTileDistributionEncoding_,
|
||||
typename TileDistributionDetail_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
|
||||
Ys2DDescriptor_,
|
||||
StaticTileDistributionEncoding_,
|
||||
TileDistributionDetail_>& distribution)
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(StaticTileDistributionEncoding_{});
|
||||
printf(", ");
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(distribution.ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
printf("ys_to_d_: ");
|
||||
print(distribution.ys_to_d_);
|
||||
printf("}\n");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -47,6 +47,11 @@ struct tile_distribution_encoding
|
||||
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
|
||||
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
|
||||
|
||||
#if !CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
|
||||
"do not support Y dim pointed to R dim");
|
||||
#endif
|
||||
|
||||
// redundant but useful info
|
||||
// TODO: really bad code, should be over-hauled
|
||||
struct detail
|
||||
@@ -255,33 +260,107 @@ struct tile_distribution_encoding
|
||||
}
|
||||
}();
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_h_dim_lengths()
|
||||
{
|
||||
// <len_d0, len_d1, ...>
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
|
||||
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
constexpr index_t size = HsLengthss{}[i].size();
|
||||
return number<size>{};
|
||||
constexpr index_t size_ = HsLengthss{}[i].size();
|
||||
return number<size_>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
return uniformed_h_dim_lengths;
|
||||
}
|
||||
|
||||
// note: this function only count the p dim length along h, not r
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_p_dim_lengths_over_h()
|
||||
{
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
|
||||
// Y P Y Y P Y P Y
|
||||
// | | |
|
||||
// v v v
|
||||
// return : seq<4, 2 * 4> => seq<4, 8>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto p_len_ = [&]() {
|
||||
array<index_t, NDimX> len_{1};
|
||||
static_for<0, NDimX, 1>{}([&](auto idim_x_) {
|
||||
constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
|
||||
static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
|
||||
if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
|
||||
{
|
||||
constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
|
||||
constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
|
||||
len_[idim_x_] *= h_length_;
|
||||
}
|
||||
});
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
|
||||
return p_len_over_h_seq_;
|
||||
}
|
||||
|
||||
//
|
||||
// R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
|
||||
// => return seq<1, 3, 5>
|
||||
// R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
|
||||
// => return seq<0, 2, 3>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_rh_dim_lengths()
|
||||
{
|
||||
constexpr auto uniformed_rh_dim_lengths =
|
||||
merge_sequences(sequence<NDimR>{} /*for R dims*/, get_uniformed_h_dim_lengths());
|
||||
|
||||
return uniformed_rh_dim_lengths;
|
||||
}
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
|
||||
|
||||
return h_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_rh_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
|
||||
|
||||
return rh_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_p_to_h()
|
||||
{
|
||||
// tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
|
||||
constexpr auto uniformed_ps_to_rhss_major_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
|
||||
constexpr auto uniformed_ps_to_rhss_minor_ =
|
||||
unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
|
||||
|
||||
constexpr auto all_ps_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
uniformed_ps_to_rhss_major_,
|
||||
uniformed_ps_to_rhss_minor_);
|
||||
|
||||
return all_ps_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_rh()
|
||||
{
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
// <0, 0, len_d0, len_d0+len_d1, ...>
|
||||
constexpr auto x_dim_prefix_sum = merge_sequences(
|
||||
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
|
||||
return x_dim_prefix_sum.at(major) + minor;
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
@@ -289,6 +368,45 @@ struct tile_distribution_encoding
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
{
|
||||
// TODO: Y can't point to R
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
|
||||
return rh_dim_prefix_sum.at(major) + minor - NDimR;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
// return tuple of seq
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
|
||||
{
|
||||
constexpr auto masks_ = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto size_ = HsLengthss{}[i].size();
|
||||
constexpr auto current_y_to_h_mask_ = [&]() {
|
||||
array<index_t, size_> m_{0};
|
||||
// TODO: we loop over all y for each h dim
|
||||
for(auto j = 0; j < NDimY; j++)
|
||||
{
|
||||
if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
|
||||
{
|
||||
m_[Ys2RHsMinor{}[j]] = 1;
|
||||
}
|
||||
}
|
||||
return m_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(current_y_to_h_mask_, size_);
|
||||
},
|
||||
number<NDimX>{});
|
||||
return masks_;
|
||||
}
|
||||
|
||||
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
|
||||
template <typename IdxSeq, typename PrefixSumSeq>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
|
||||
@@ -305,115 +423,34 @@ struct tile_distribution_encoding
|
||||
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
|
||||
// Note here y_to_h does not count R dim!
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_to_h_info()
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_span_major_: ");
|
||||
print(ndim_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_lengthss_: ");
|
||||
print(rhs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_lengths_: ");
|
||||
print(ys_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_span_minor_: ");
|
||||
print(ndims_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_major_: ");
|
||||
print(ys_to_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(ys_to_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(ps_over_rs_derivative_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
//
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
|
||||
//
|
||||
printf("rs_lengths_: ");
|
||||
print(rs_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("hs_lengthss_: ");
|
||||
print(hs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("detail: ");
|
||||
print(detail{});
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename encoding, typename shuffle>
|
||||
class tile_distribution_encoding_shuffle;
|
||||
template <typename encoding, index_t... shuffle>
|
||||
class tile_distribution_encoding_shuffle<encoding, sequence<shuffle...>>
|
||||
{
|
||||
template <typename Ys2RHs>
|
||||
using shuffled = sequence<(Ys2RHs::template get<shuffle>())...>;
|
||||
|
||||
public:
|
||||
using type = tile_distribution_encoding<typename encoding::RsLengths,
|
||||
typename encoding::HsLengthss,
|
||||
typename encoding::Ps2RHssMajor,
|
||||
typename encoding::Ps2RHssMinor,
|
||||
shuffled<typename encoding::Ys2RHsMajor>,
|
||||
shuffled<typename encoding::Ys2RHsMinor>>;
|
||||
};
|
||||
template <typename encoding, typename shuffle>
|
||||
using tile_distribution_encoding_shuffle_t =
|
||||
typename tile_distribution_encoding_shuffle<encoding, shuffle>::type;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename OuterDstr, typename InnerDstr>
|
||||
@@ -757,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution_encoding::detail
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
print(const typename tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>::detail& detail_obj)
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("ndim_span_major_: ");
|
||||
print(detail_obj.ndim_span_major_);
|
||||
printf(", ");
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(detail_obj.ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(detail_obj.max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
printf("rhs_lengthss_: ");
|
||||
print(detail_obj.rhs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ys_lengths_: ");
|
||||
print(detail_obj.ys_lengths_);
|
||||
printf(", ");
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(detail_obj.rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
printf("ndims_span_minor_: ");
|
||||
print(detail_obj.ndims_span_minor_);
|
||||
printf(", ");
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(detail_obj.max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_major_: ");
|
||||
print(detail_obj.ys_to_span_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(detail_obj.ys_to_span_minor_);
|
||||
printf(", ");
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(detail_obj.distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(detail_obj.ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(detail_obj.ps_over_rs_derivative_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// Free print function for tile_distribution_encoding
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>& encoding)
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
|
||||
printf("rs_lengths_: ");
|
||||
print(encoding.rs_lengths_);
|
||||
printf(", ");
|
||||
printf("hs_lengthss_: ");
|
||||
print(encoding.hs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(encoding.ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(encoding.ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(encoding.ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(encoding.ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -59,6 +59,38 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
|
||||
*
|
||||
* @param in_element_func Function to apply element-wise.
|
||||
* @param t Any container containing elements to process, with known size and
|
||||
* tuple-like semantic.
|
||||
* @return Calls tile_elementwise_inout with unpacked tuple elements.
|
||||
*/
|
||||
template <typename InElementFunc, typename Tuple, size_t... I>
|
||||
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
|
||||
const Tuple& t,
|
||||
std::index_sequence<I...>)
|
||||
{
|
||||
return tile_elementwise_inout(in_element_func, t[number<I>{}]...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
|
||||
*
|
||||
* @param in_element_func Function to apply element-wise.
|
||||
* @param t Any container containing elements to process, with known size and
|
||||
* tuple-like semantic.
|
||||
* @return Calls the overloaded function, passing an index sequence.
|
||||
*/
|
||||
template <typename InElementFunc, typename Tuple>
|
||||
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
|
||||
const Tuple& t)
|
||||
{
|
||||
static constexpr auto size = Tuple::size();
|
||||
return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
|
||||
}
|
||||
|
||||
template <typename DstrTensors, typename T>
|
||||
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
|
||||
{
|
||||
@@ -295,9 +327,8 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
|
||||
template <typename DstType, typename SrcTensor>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
|
||||
{
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> ||
|
||||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
|
||||
float> &&
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
|
||||
std::is_same_v<typename SrcTensor::DataType, float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 4 == 0))
|
||||
{
|
||||
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
|
||||
|
||||
858
include/ck_tile/core/tensor/tile_scatter_gather.hpp
Normal file
858
include/ck_tile/core/tensor/tile_scatter_gather.hpp
Normal file
@@ -0,0 +1,858 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This class provides tile (windowed) view and access to the device memory.
|
||||
*
|
||||
* @note This tile window does not support single issue you need to use tile_window_linear
|
||||
* structure for this purpose
|
||||
*
|
||||
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
|
||||
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
|
||||
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
|
||||
* @tparam NumCoord TBD
|
||||
*/
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
index_t YsGatherDim = 0>
|
||||
struct tile_scatter_gather
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
|
||||
using ValidArray = remove_cvref_t<StaticValidArray_>;
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static_assert(NumCoord == 1);
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
|
||||
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord =
|
||||
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
|
||||
|
||||
struct load_store_traits
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_scatter_gather::get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
|
||||
// using vector_t = typename vector_type_t::type;
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_)>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
|
||||
};
|
||||
|
||||
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution,
|
||||
const PageIdxArray& page_idx,
|
||||
const ValidArray& valids)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
page_idx_{page_idx},
|
||||
valids_{valids},
|
||||
pre_computed_coords_{}
|
||||
{
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
|
||||
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
|
||||
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value = [&]() {
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
}();
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: currently async load only implemented in inline asm
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
// using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
constexpr auto pre_nop_ = [&]() {
|
||||
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
0,
|
||||
pre_nop_);
|
||||
}
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
m0_inc_with_memory(size_per_issue);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
// printf("off %d\n", page_idx_[I0]);
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<0>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
|
||||
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
|
||||
|
||||
// read from distributed tensor
|
||||
// vector_type_t vec;
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
// printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// printf("coord_offset:%d, scatter_offset:%d \n",
|
||||
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
BottomTensorIndex step_new = step;
|
||||
step_new(HsGatherDim) = 0;
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
pre_computed_coords_(iCoord)(I1),
|
||||
step_new);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
|
||||
|
||||
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
|
||||
{
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
|
||||
{
|
||||
valids_ = new_valids;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx,
|
||||
const ValidArray& new_valids)
|
||||
{
|
||||
update_page_idx(new_idx);
|
||||
update_valids(new_valids);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
|
||||
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
|
||||
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
PageIdxArray page_idx_;
|
||||
ValidArray valids_;
|
||||
|
||||
// this contains:
|
||||
// per-thread coordinate for window adaptor
|
||||
// per-thread coordinate for bottom tensor
|
||||
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
number<HsGatherDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
std::nullptr_t,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
origin,
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
const StaticValidArray_& valids,
|
||||
number<HsGatherDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
remove_cvref_t<StaticValidArray_>,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
typename StaticValidArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
const StaticValidArray& valids,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
origin,
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
valids,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
typename StaticValidArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
const StaticValidArray& valids,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
valids,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
@@ -0,0 +1,256 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
* @note This class does not provide any functions to read or modify device memory.
|
||||
*
|
||||
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
|
||||
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
|
||||
*/
|
||||
template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
|
||||
struct tile_window_base
|
||||
{
|
||||
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
// Delegate to child if it implements extra logic
|
||||
static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
|
||||
}
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {}
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
// Delegate to child if it implements extra movement logic
|
||||
static_cast<TileWindowType_*>(this)->move_extended(step);
|
||||
}
|
||||
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {}
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
};
|
||||
|
||||
template <typename TileWindowType_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_>
|
||||
struct tile_window_with_tile_dstr_base
|
||||
: public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
|
||||
{
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using TileWindowBase = tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>;
|
||||
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
// using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord = decltype(make_tensor_coordinate(
|
||||
typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{}));
|
||||
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, TileWindowBase::NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
struct Traits
|
||||
{
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename TileWindowBase::DataType>>::PackedSize;
|
||||
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
using vector_t =
|
||||
thread_buffer<typename TileWindowBase::DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_),
|
||||
false /*!!! no snaked curve! */>{};
|
||||
}
|
||||
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
};
|
||||
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; }
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,13 @@
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TileWindow_>
|
||||
CK_TILE_DEVICE void move_tile_window(TileWindow_& window,
|
||||
const typename TileWindow_::BottomTensorIndex& step)
|
||||
{
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
// input a lds store tile, extract some information from it
|
||||
// used to set m0 value for gfx9 serious
|
||||
template <typename LdsTileWindow_>
|
||||
|
||||
@@ -83,12 +83,14 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
|
||||
constexpr index_t num_vec_in = vec_length_out;
|
||||
constexpr index_t num_vec_out = vec_length_in;
|
||||
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
|
||||
[&](auto i) {
|
||||
if constexpr(vec_length_in == 1)
|
||||
return 1;
|
||||
else
|
||||
return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
|
||||
@@ -101,51 +103,90 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
|
||||
|
||||
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
|
||||
|
||||
// in/out vectors to be transposed
|
||||
thread_buffer<InVec, num_vec_in> in_vectors;
|
||||
thread_buffer<OutVec, num_vec_out> out_vectors;
|
||||
|
||||
// loop over SFC and do transpose
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
if constexpr(num_vec_in == 1 || num_vec_out == 1)
|
||||
{
|
||||
// loop over SFC
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
constexpr auto idx_y_in =
|
||||
generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer()
|
||||
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
|
||||
});
|
||||
|
||||
// transpose
|
||||
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out_tmp =
|
||||
generate_array([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
if constexpr(vec_length_in == 1)
|
||||
{
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
out_tensor.get_thread_buffer()[number<out_offset>{}] =
|
||||
in_tensor.get_thread_buffer()[number<in_offset>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
using Vec = array<DataType, vec_length_in>;
|
||||
out_tensor.get_thread_buffer().template get_as<Vec>(
|
||||
number<out_offset / vec_length_in>{}) =
|
||||
in_tensor.get_thread_buffer().template get_as<Vec>(
|
||||
number<in_offset / vec_length_in>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
// in/out vectors to be transposed
|
||||
thread_buffer<InVec, num_vec_in> in_vectors;
|
||||
thread_buffer<OutVec, num_vec_out> out_vectors;
|
||||
|
||||
// loop over SFC and do transpose
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer()
|
||||
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
|
||||
});
|
||||
|
||||
// transpose
|
||||
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
156
include/ck_tile/core/utility/debug.hpp
Normal file
156
include/ck_tile/core/utility/debug.hpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <stdio.h>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
|
||||
template <char... Xs>
|
||||
struct str_literal
|
||||
{
|
||||
static constexpr const char data[] = {Xs..., '\0'};
|
||||
static constexpr const size_t size = sizeof...(Xs);
|
||||
|
||||
template <char... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
|
||||
{
|
||||
return str_literal<Xs..., Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t N, char... Ys>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
|
||||
{
|
||||
if constexpr(N == 0)
|
||||
return str_literal<>{};
|
||||
else if constexpr(N == 1)
|
||||
return str_literal<Xs...>{};
|
||||
else
|
||||
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
|
||||
}
|
||||
};
|
||||
|
||||
#define make_str_literal(lit_) \
|
||||
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
|
||||
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
|
||||
|
||||
template <size_t... Idx>
|
||||
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
{
|
||||
return {};
|
||||
}
|
||||
constexpr size_t constexpr_strlen(const char* c)
|
||||
{
|
||||
size_t t = 0;
|
||||
while(*c++)
|
||||
++t;
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct thread_buffer;
|
||||
|
||||
// Usage example: CK_PRINTF<float>{}(tensor);
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF;
|
||||
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
|
||||
struct CK_PRINTF<ConvertTo,
|
||||
str_literal<FMTChars...>,
|
||||
str_literal<PREFIXChars...>,
|
||||
str_literal<SUFFIXChars...>>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto default_format()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
return make_str_literal("%8.3f");
|
||||
else if constexpr(std::is_same_v<T, int>)
|
||||
return make_str_literal("%5d");
|
||||
else if constexpr(std::is_same_v<T, unsigned int>)
|
||||
return make_str_literal("%5u");
|
||||
else
|
||||
return make_str_literal("0x%08x");
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
|
||||
{
|
||||
constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
|
||||
if constexpr(sizeof...(PREFIXChars) == 0)
|
||||
return fmt_tid;
|
||||
else
|
||||
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
|
||||
{
|
||||
constexpr auto lf = make_str_literal("\n");
|
||||
if constexpr(sizeof...(SUFFIXChars) == 0)
|
||||
return lf;
|
||||
else
|
||||
return str_literal<SUFFIXChars...>{} + lf;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename Y, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
|
||||
std::integer_sequence<index_t, Is...>) const
|
||||
{
|
||||
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
|
||||
decltype(default_format<Y>()),
|
||||
str_literal<FMTChars...>>;
|
||||
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
|
||||
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
|
||||
{
|
||||
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
|
||||
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
|
||||
}
|
||||
|
||||
template <typename... TS>
|
||||
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
|
||||
{
|
||||
return operator()(tensor.get_thread_buffer());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
|
||||
{
|
||||
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
base_t::operator()(buf);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -202,3 +202,7 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
|
||||
@@ -58,6 +58,30 @@ struct static_for
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T, T... Is>
|
||||
struct applier
|
||||
{
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
|
||||
(f(number<Is>{}), ...);
|
||||
}
|
||||
};
|
||||
|
||||
template <int32_t Size> // == sizeof...(Is)
|
||||
using make_applier = __make_integer_seq<applier, index_t, Size>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t N>
|
||||
struct static_for<0, N, 1> : detail::make_applier<N>
|
||||
{
|
||||
using detail::make_applier<N>::operator();
|
||||
};
|
||||
|
||||
struct identity
|
||||
{
|
||||
template <typename T>
|
||||
|
||||
@@ -38,7 +38,7 @@ struct magic_division32_bit_range
|
||||
shift_u32++;
|
||||
};
|
||||
|
||||
uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32;
|
||||
uint64_t tmp_u64 = static_cast<uint64_t>((1UL << shift_u32) - divisor) << 32;
|
||||
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
|
||||
|
||||
return make_tuple(multiplier_u32, shift_u32);
|
||||
|
||||
76
include/ck_tile/core/utility/print.hpp
Normal file
76
include/ck_tile/core/utility/print.hpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
|
||||
/// can be printed.
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void print(const T&)
|
||||
{
|
||||
static_assert(sizeof(T) == 0,
|
||||
"No print implementation available for this type. Please specialize "
|
||||
"ck_tile::print for your type.");
|
||||
}
|
||||
|
||||
/// Specialization for int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const int& value)
|
||||
{
|
||||
printf("%d", value);
|
||||
}
|
||||
|
||||
/// Specialization for float
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const float& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for double
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const double& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for long
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const long& value)
|
||||
{
|
||||
printf("%ld", value);
|
||||
}
|
||||
|
||||
/// Specialization for unsigned int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
|
||||
{
|
||||
printf("%u", value);
|
||||
}
|
||||
|
||||
/// Specialization for char
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const char& value)
|
||||
{
|
||||
printf("%c", value);
|
||||
}
|
||||
|
||||
/// Specialization for array
|
||||
template <typename T, size_t N>
|
||||
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
|
||||
{
|
||||
printf("[");
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(value[i]); // Recursively call print for each element
|
||||
}
|
||||
printf("]");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -26,7 +26,8 @@ struct Add
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t>>>
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -51,13 +52,25 @@ struct SquareAdd
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
float x_ = type_convert<float>(x);
|
||||
return type_convert<T>(y_ + (x_ * x_));
|
||||
}
|
||||
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
@@ -65,7 +78,9 @@ struct Max
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
@@ -76,7 +91,9 @@ struct AbsMax
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
@@ -84,7 +101,9 @@ struct AbsMax
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
@@ -57,8 +58,8 @@ struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
|
||||
|
||||
struct nonesuch
|
||||
{
|
||||
~nonesuch() = delete;
|
||||
nonesuch(nonesuch const&) = delete;
|
||||
~nonesuch() = delete;
|
||||
nonesuch(nonesuch const&) = delete;
|
||||
void operator=(nonesuch const&) = delete;
|
||||
};
|
||||
|
||||
@@ -127,4 +128,44 @@ struct is_any_of<CompareTo, FirstType, Rest...>
|
||||
{
|
||||
};
|
||||
|
||||
// Helper to check if a type is a specialization of a given template
|
||||
template <typename Test, template <typename...> class RefTemplate>
|
||||
struct is_specialization_of : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <template <typename...> class RefTemplate, typename... Args>
|
||||
struct is_specialization_of<RefTemplate<Args...>, RefTemplate> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
// Helper to get a tuple element or default type
|
||||
namespace detail {
|
||||
|
||||
template <bool IsWithinBounds, std::size_t Idx, typename Tuple, typename DefaultType>
|
||||
struct tuple_element_or_default_dispatch
|
||||
{
|
||||
using type = DefaultType;
|
||||
};
|
||||
|
||||
template <std::size_t Idx, typename Tuple, typename DefaultType>
|
||||
struct tuple_element_or_default_dispatch<true, Idx, Tuple, DefaultType>
|
||||
{
|
||||
using type = std::tuple_element_t<Idx, Tuple>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
struct tuple_element_or_default
|
||||
{
|
||||
using Tuple = remove_cvref_t<Tuple_>;
|
||||
static constexpr bool is_within_bounds = Idx < std::tuple_size_v<Tuple>;
|
||||
using type = typename detail::
|
||||
tuple_element_or_default_dispatch<is_within_bounds, Idx, Tuple, DefaultType>::type;
|
||||
};
|
||||
template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
using tuple_element_or_default_t =
|
||||
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -49,7 +49,7 @@ struct composes<F>
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
template <typename... Ts>
|
||||
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
|
||||
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
|
||||
|
||||
template <typename SaturateType>
|
||||
struct saturates
|
||||
@@ -57,8 +57,8 @@ struct saturates
|
||||
// NOTE: this function does not return SaturateType value
|
||||
// it is user's responsiblity to do further cast or not
|
||||
template <typename AccType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
|
||||
-> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
operator()(const AccType& a_) const -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
|
||||
{
|
||||
return clamp(a_,
|
||||
type_convert<AccType>(numeric<SaturateType>::lowest()),
|
||||
|
||||
@@ -9,7 +9,9 @@
|
||||
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/fill.hpp"
|
||||
#include "ck_tile/host/flush_icache.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
@@ -25,6 +27,8 @@
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_fused_moe.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
@@ -34,5 +38,8 @@
|
||||
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
#include "ck_tile/host/rotating_buffers.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
|
||||
@@ -18,16 +18,36 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/** @brief 8-bit floating point type */
|
||||
using F8 = ck_tile::fp8_t;
|
||||
/** @brief 8-bit brain floating point type */
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
/** @brief 16-bit floating point (half precision) type */
|
||||
using F16 = ck_tile::half_t;
|
||||
/** @brief 16-bit brain floating point type */
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
/** @brief 32-bit floating point (single precision) type */
|
||||
using F32 = float;
|
||||
/** @brief 8-bit signed integer type */
|
||||
using I8 = int8_t;
|
||||
/** @brief 32-bit signed integer type */
|
||||
using I32 = int32_t;
|
||||
|
||||
/**
|
||||
* @brief Calculate relative error threshold for numerical comparisons
|
||||
*
|
||||
* Calculates the relative error threshold based on the mantissa bits and characteristics
|
||||
* of the data types involved in the computation.
|
||||
*
|
||||
* @tparam ComputeDataType Type used for computation
|
||||
* @tparam OutDataType Type used for output
|
||||
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
|
||||
* @param number_of_accumulations Number of accumulation operations performed
|
||||
* @return Relative error threshold based on data type characteristics
|
||||
*/
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
@@ -72,16 +92,23 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate absolute error threshold for numerical comparisons
|
||||
*
|
||||
* Calculates the absolute error threshold based on the maximum possible value and
|
||||
* the characteristics of the data types involved in the computation.
|
||||
*
|
||||
* @tparam ComputeDataType Type used for computation
|
||||
* @tparam OutDataType Type used for output
|
||||
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
|
||||
* @param max_possible_num Maximum possible value in the computation
|
||||
* @param number_of_accumulations Number of accumulation operations performed
|
||||
* @return Absolute error threshold based on data type characteristics and maximum value
|
||||
*/
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
|
||||
const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
@@ -128,6 +155,16 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stream operator overload for vector output
|
||||
*
|
||||
* Provides a formatted string representation of a vector, useful for debugging and logging.
|
||||
*
|
||||
* @tparam T Type of vector elements
|
||||
* @param os Output stream
|
||||
* @param v Vector to output
|
||||
* @return Reference to the output stream
|
||||
*/
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
@@ -145,6 +182,66 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check for size mismatch between output and reference ranges
|
||||
*
|
||||
* Verifies that the output and reference ranges are the same size.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if sizes mismatch
|
||||
* @return True if sizes mismatch, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
CK_TILE_HOST bool check_size_mismatch(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!")
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Report error statistics for numerical comparisons
|
||||
*
|
||||
* Outputs statistics about numerical comparison errors including count and maximum error.
|
||||
*
|
||||
* @param err_count Number of errors found
|
||||
* @param max_err Maximum error value encountered
|
||||
* @param total_size Total number of elements compared
|
||||
*/
|
||||
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between floating point ranges using the specified tolerances.
|
||||
*
|
||||
* Compares two ranges of floating point values within specified relative and absolute tolerances.
|
||||
* This overload handles standard floating point types except half precision floating point.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
@@ -158,12 +255,9 @@ check_err(const Range& out,
|
||||
double atol = 3e-6,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -196,15 +290,27 @@ check_err(const Range& out,
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between floating point ranges using the specified tolerances
|
||||
*
|
||||
* Compares two ranges of brain floating point values within specified relative and absolute
|
||||
* tolerances.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
@@ -217,12 +323,8 @@ check_err(const Range& out,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -256,15 +358,28 @@ check_err(const Range& out,
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between half precision floating point ranges
|
||||
*
|
||||
* Compares two ranges of half precision floating point values within specified tolerances.
|
||||
* This specialization handles the specific requirements and characteristics of half precision
|
||||
* floating point comparisons.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
@@ -277,12 +392,8 @@ check_err(const Range& out,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -315,15 +426,26 @@ check_err(const Range& out,
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between integer ranges
|
||||
*
|
||||
* Compares two ranges of integer values with an absolute tolerance.
|
||||
* This specialization handles integer types and optionally int4_t when the
|
||||
* experimental bit int extension is enabled.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param atol Absolute tolerance
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_integral_v<ranges::range_value_t<Range>> &&
|
||||
@@ -339,12 +461,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
double = 0,
|
||||
double atol = 0)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
@@ -370,15 +488,28 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, static_cast<double>(max_err), ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between FP8 ranges
|
||||
*
|
||||
* Specialized comparison for 8-bit floating point values that takes into account
|
||||
* the unique characteristics and limitations of FP8 arithmetic, including
|
||||
* rounding point distances and special handling of infinity values.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param max_rounding_point_distance Maximum allowed distance between rounding points
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
|
||||
@@ -390,12 +521,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
double atol = 1e-1,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -447,15 +574,27 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between BF8 ranges
|
||||
*
|
||||
* Specialized comparison for 8-bit brain floating point values that considers
|
||||
* the specific numerical properties and error characteristics of the BF8 format.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
|
||||
@@ -467,12 +606,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -505,11 +640,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -33,13 +33,14 @@ struct IsCharArray<const char (&)[N]> : std::true_type
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v<Ts, std::string_view> ||
|
||||
IsCharArray<Ts>::value ||
|
||||
std::is_same_v<Ts, char>)&&...);
|
||||
inline constexpr bool AllConvertibleToStringView =
|
||||
((std::is_convertible_v<Ts, std::string_view> || IsCharArray<Ts>::value ||
|
||||
std::is_same_v<Ts, char>) &&
|
||||
...);
|
||||
|
||||
template <typename... Ts>
|
||||
[[nodiscard]] auto concat(const Ts&... xs)
|
||||
-> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
|
||||
[[nodiscard]] auto
|
||||
concat(const Ts&... xs) -> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
|
||||
{
|
||||
using ::operator<<;
|
||||
thread_local std::ostringstream oss;
|
||||
@@ -78,8 +79,8 @@ template <std::size_t N>
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
auto concatInto(std::string& result, const Ts&... xs)
|
||||
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
|
||||
auto concatInto(std::string& result,
|
||||
const Ts&... xs) -> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
|
||||
{
|
||||
const std::size_t space = (1 + ... + getSize(xs));
|
||||
result.reserve(result.size() + space);
|
||||
@@ -87,8 +88,8 @@ auto concatInto(std::string& result, const Ts&... xs)
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
[[nodiscard]] auto concat(const Ts&... xs)
|
||||
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
|
||||
[[nodiscard]] auto
|
||||
concat(const Ts&... xs) -> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
|
||||
{
|
||||
std::string result;
|
||||
concatInto(result, xs...);
|
||||
|
||||
@@ -20,10 +20,35 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Container for storing data in GPU device memory
|
||||
* @brief Manages device memory allocation and host-device data transfers
|
||||
*
|
||||
* DeviceMem encapsulates GPU memory management operations using HIP runtime API.
|
||||
* It provides functionality for allocating device memory, transferring data between
|
||||
* host and device, and performing basic memory operations.
|
||||
*
|
||||
* Key features:
|
||||
* - Automatic memory allocation and deallocation
|
||||
* - Host-to-device and device-to-host data transfers
|
||||
* - Memory initialization operations
|
||||
* - Integration with HostTensor for simplified data handling
|
||||
*
|
||||
* Usage example:
|
||||
* ```
|
||||
* // Allocate device memory
|
||||
* BHostTensor<float> AHostData({256});
|
||||
* DeviceMem d_mem(BHostData.get_element_space_size_in_bytes());
|
||||
*
|
||||
* // Transfer data to device
|
||||
* HostTensor<float> AHostTensor({256});
|
||||
* d_mem.ToDevice(AHostData.data());
|
||||
*
|
||||
* // Retrieve data from device
|
||||
* HostTensor<float> ResultHostTensor({256});
|
||||
* d_mem.FromDevice(ResultHostTensor.data());
|
||||
* ```
|
||||
*/
|
||||
struct DeviceMem
|
||||
|
||||
{
|
||||
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
|
||||
DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
@@ -163,8 +188,8 @@ struct DeviceMem
|
||||
}
|
||||
}
|
||||
|
||||
void* mpDeviceBuf;
|
||||
std::size_t mMemSize;
|
||||
void* mpDeviceBuf; ///< pointer to device buffer
|
||||
std::size_t mMemSize; ///< size of device buffer in bytes
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
75
include/ck_tile/host/device_prop.hpp
Normal file
75
include/ck_tile/host/device_prop.hpp
Normal file
@@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u)
|
||||
{
|
||||
return str.empty() ? h
|
||||
: fnv1a_hash(str.substr(1),
|
||||
(h ^ static_cast<unsigned char>(str.front())) * 16777619u);
|
||||
}
|
||||
inline std::string get_device_name()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
const std::string raw_name(props.gcnArchName);
|
||||
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
|
||||
switch(fnv1a_hash(name))
|
||||
{
|
||||
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
|
||||
case fnv1a_hash("Ellesmere"):
|
||||
case fnv1a_hash("Baffin"):
|
||||
case fnv1a_hash("RacerX"):
|
||||
case fnv1a_hash("Polaris10"):
|
||||
case fnv1a_hash("Polaris11"):
|
||||
case fnv1a_hash("Tonga"):
|
||||
case fnv1a_hash("Fiji"):
|
||||
case fnv1a_hash("gfx800"):
|
||||
case fnv1a_hash("gfx802"):
|
||||
case fnv1a_hash("gfx804"): return "gfx803";
|
||||
case fnv1a_hash("Vega10"):
|
||||
case fnv1a_hash("gfx901"): return "gfx900";
|
||||
case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030";
|
||||
default: return name;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_gfx11_supported()
|
||||
{
|
||||
return get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
|
||||
get_device_name() == "gfx1102" || get_device_name() == "gfx1103" ||
|
||||
get_device_name() == "gfx1150" || get_device_name() == "gfx1151" ||
|
||||
get_device_name() == "gfx1152";
|
||||
}
|
||||
|
||||
inline bool is_gfx12_supported()
|
||||
{
|
||||
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
|
||||
}
|
||||
|
||||
inline bool is_load_tr_supported()
|
||||
{
|
||||
// Check if load transpose is supported.
|
||||
return get_device_name() == "gfx950";
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <unordered_set>
|
||||
@@ -17,13 +18,31 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Functor for filling a range with randomly generated values from a uniform distribution.
|
||||
*
|
||||
* This struct provides functionality to fill iterators or ranges with random values
|
||||
* generated from a uniform distribution. It supports both single-threaded and
|
||||
* multi-threaded operation.
|
||||
*
|
||||
* @tparam T The target type for the generated values.
|
||||
*
|
||||
* @note The multi-threaded implementation is not guaranteed to provide perfectly
|
||||
* distributed values across threads.
|
||||
*
|
||||
* @example
|
||||
*
|
||||
* // Direct usage without creating a separate variable:
|
||||
* ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host_tensor);
|
||||
*/
|
||||
template <typename T>
|
||||
struct FillUniformDistribution
|
||||
{
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
// ATTENTION: threaded does not guarantee the distribution between thread
|
||||
// ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed
|
||||
// across threads).
|
||||
bool threaded = false;
|
||||
|
||||
template <typename ForwardIter>
|
||||
@@ -45,7 +64,7 @@ struct FillUniformDistribution
|
||||
return;
|
||||
// need to make each thread unique, add an offset to current seed
|
||||
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
|
||||
: std::random_device{}());
|
||||
: std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
@@ -74,6 +93,60 @@ struct FillUniformDistribution
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FillUniformDistribution<ck_tile::pk_int4_t>
|
||||
{
|
||||
float a_{-8.f}; // same type as primary template so that
|
||||
// `FillUniformDistribution<Type>{-5.0f, 5.0f}` works for all types
|
||||
float b_{7.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
if(a_ < -8.0f || b_ > 7.0f)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"a_ or b_ of FillUniformDistribution<ck_tile::pk_int4_t> is out of range.");
|
||||
}
|
||||
|
||||
int min_value = static_cast<int>(a_);
|
||||
int max_value = static_cast<int>(b_);
|
||||
constexpr auto int4_array = std::array<uint8_t, 16>{0x88,
|
||||
0x99,
|
||||
0xaa,
|
||||
0xbb,
|
||||
0xcc,
|
||||
0xdd,
|
||||
0xee,
|
||||
0xff,
|
||||
0x00,
|
||||
0x11,
|
||||
0x22,
|
||||
0x33,
|
||||
0x44,
|
||||
0x55,
|
||||
0x66,
|
||||
0x77};
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_int_distribution<std::int32_t> dis(0, max_value - min_value + 1);
|
||||
while(first != last)
|
||||
{
|
||||
int randomInt = dis(gen);
|
||||
*first = int4_array[randomInt + (min_value + 8)];
|
||||
++first;
|
||||
}
|
||||
}
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
// clang-format off
|
||||
@@ -169,7 +242,7 @@ struct FillNormalDistribution
|
||||
return;
|
||||
// need to make each thread unique, add an offset to current seed
|
||||
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
|
||||
: std::random_device{}());
|
||||
: std::random_device{}());
|
||||
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
|
||||
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
@@ -280,7 +353,7 @@ struct FillMonotonicSeq
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::generate(first, last, [=, n = init_value_]() mutable {
|
||||
std::generate(first, last, [=, *this, n = init_value_]() mutable {
|
||||
auto tmp = n;
|
||||
if constexpr(std::is_same_v<decltype(tmp), pk_int4_t>)
|
||||
{
|
||||
@@ -315,7 +388,7 @@ struct FillStepRange
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::generate(first, last, [=, n = start_value_]() mutable {
|
||||
std::generate(first, last, [=, *this, n = start_value_]() mutable {
|
||||
auto tmp = n;
|
||||
n += step_;
|
||||
if constexpr(IsAscending)
|
||||
@@ -334,9 +407,10 @@ struct FillStepRange
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<
|
||||
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillStepRange&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
@@ -355,9 +429,53 @@ struct FillConstant
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<
|
||||
decltype(std::declval<const FillConstant&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillConstant&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Transforms given input to fit 2:4 structured sparsity pattern so
|
||||
/// every subgroup of 4 elements contain at most 2 non-zero elements
|
||||
template <typename T>
|
||||
struct AdjustToStructuredSparsity
|
||||
{
|
||||
size_t start{0};
|
||||
// masks represent all valid 2:4 structured sparsity permutations
|
||||
// clang-format off
|
||||
static constexpr int32_t masks[] = {0, 0, 1, 1,
|
||||
0, 1, 0, 1,
|
||||
0, 1, 1, 0,
|
||||
1, 0, 0, 1,
|
||||
1, 0, 1, 0,
|
||||
1, 1, 0, 0,
|
||||
0, 0, 0, 1,
|
||||
0, 0, 1, 0,
|
||||
0, 1, 0, 0,
|
||||
1, 0, 0, 0};
|
||||
// clang-format on
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::transform(first, last, first, [=, *this, index = start](T val) mutable {
|
||||
auto tmp = val * masks[index % (sizeof(masks) / sizeof(int32_t))];
|
||||
index += 1;
|
||||
|
||||
return type_convert<T>(tmp);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const AdjustToStructuredSparsity&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
@@ -396,9 +514,10 @@ struct FillTrigValue
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<
|
||||
decltype(std::declval<const FillTrigValue&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillTrigValue&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
|
||||
30
include/ck_tile/host/flush_icache.hpp
Normal file
30
include/ck_tile/host/flush_icache.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
static __global__ void flush_cache()
|
||||
{
|
||||
asm __volatile__("s_icache_inv \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t" ::
|
||||
:);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -85,6 +85,19 @@ CK_TILE_HOST auto construct_f_unpack_args(F, T args)
|
||||
return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Descriptor for tensors in host memory.
|
||||
*
|
||||
* HostTensorDescriptor manages the shape (dimensions) and memory layout (strides)
|
||||
* of a tensor in host memory. It provides functionality to:
|
||||
* - Store tensor dimensions and strides
|
||||
* - Calculate default strides for contiguous memory layout
|
||||
* - Convert multi-dimensional indices to linear memory offsets
|
||||
* - Query tensor metadata (dimensions, element counts, etc.)
|
||||
*
|
||||
* The class supports both automatic stride calculation for contiguous memory layout
|
||||
* and custom strides for more complex memory patterns.
|
||||
*/
|
||||
struct HostTensorDescriptor
|
||||
{
|
||||
HostTensorDescriptor() = default;
|
||||
@@ -138,12 +151,35 @@ struct HostTensorDescriptor
|
||||
}
|
||||
|
||||
std::size_t get_num_of_dimension() const { return mLens.size(); }
|
||||
/**
|
||||
* @brief Calculates the total number of elements in the tensor.
|
||||
*
|
||||
* Computes the product of all dimension lengths to determine the
|
||||
* total element count in the tensor.
|
||||
*
|
||||
* @pre The lengths array (mLens) and strides array (mStrides) must have
|
||||
* the same size.
|
||||
*
|
||||
* @return The total number of elements in the tensor.
|
||||
*/
|
||||
std::size_t get_element_size() const
|
||||
{
|
||||
assert(mLens.size() == mStrides.size());
|
||||
return std::accumulate(
|
||||
mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
|
||||
}
|
||||
/**
|
||||
* @brief Calculates the total element space required for the tensor in memory.
|
||||
*
|
||||
* This method computes the minimum size of contiguous memory needed to store
|
||||
* all elements of the tensor, taking into account the tensor's dimensions and
|
||||
* strides. The calculation is based on the formula: 1 + max((length_i - 1) * stride_i)
|
||||
* across all dimensions.
|
||||
*
|
||||
* Dimensions with length 0 are skipped in this calculation.
|
||||
*
|
||||
* @return The size of the tensor's element space (number of elements).
|
||||
*/
|
||||
std::size_t get_element_space_size() const
|
||||
{
|
||||
std::size_t space = 1;
|
||||
@@ -165,6 +201,18 @@ struct HostTensorDescriptor
|
||||
|
||||
const std::vector<std::size_t>& get_strides() const { return mStrides; }
|
||||
|
||||
/**
|
||||
* @brief Calculates the linear offset from multi-dimensional indices.
|
||||
*
|
||||
* Converts a set of N-dimensional indices into a single linear offset by computing
|
||||
* the inner product of the indices with the tensor's strides.
|
||||
*
|
||||
* @tparam Is Parameter pack of index types (should be convertible to std::size_t)
|
||||
* @param is Variable number of indices, one for each dimension of the tensor
|
||||
* @return std::size_t Linear offset corresponding to the given multi-dimensional indices
|
||||
*
|
||||
* @pre The number of indices must match the number of dimensions in the tensor
|
||||
*/
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
@@ -173,7 +221,16 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
|
||||
/**
|
||||
* @brief Calculates the linear memory offset from a multi-dimensional index
|
||||
*
|
||||
* Computes the linear offset by performing an inner product between the provided
|
||||
* multi-dimensional indices and the tensor's strides.
|
||||
*
|
||||
* @param iss Vector containing the multi-dimensional indices
|
||||
* @return The calculated linear offset as a size_t
|
||||
*/
|
||||
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
|
||||
{
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
@@ -194,8 +251,8 @@ struct HostTensorDescriptor
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::size_t> mLens;
|
||||
std::vector<std::size_t> mStrides;
|
||||
std::vector<std::size_t> mLens; ///< Lengths of each dimension
|
||||
std::vector<std::size_t> mStrides; ///< Strides for each dimension
|
||||
};
|
||||
|
||||
template <typename New2Old>
|
||||
@@ -321,7 +378,7 @@ struct HostTensor
|
||||
~HostTensor() = default;
|
||||
|
||||
HostTensor& operator=(const HostTensor&) = default;
|
||||
HostTensor& operator=(HostTensor&&) = default;
|
||||
HostTensor& operator=(HostTensor&&) = default;
|
||||
|
||||
template <typename FromT>
|
||||
explicit HostTensor(const HostTensor<FromT>& other) : HostTensor(other.template CopyAsType<T>())
|
||||
@@ -352,7 +409,13 @@ struct HostTensor
|
||||
}
|
||||
|
||||
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
|
||||
void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
|
||||
void SetZero()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, e8m0_t>)
|
||||
std::fill(mData.begin(), mData.end(), e8m0_t{1.f});
|
||||
else
|
||||
std::fill(mData.begin(), mData.end(), 0);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
|
||||
@@ -483,9 +546,12 @@ struct HostTensor
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
|
||||
T& operator()(const std::vector<std::size_t>& idx)
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
const T& operator()(const std::vector<std::size_t>& idx) const
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
@@ -662,6 +728,8 @@ struct HostTensor
|
||||
file << type_convert<float>(itm) << std::endl;
|
||||
else if(dtype == "int")
|
||||
file << type_convert<int>(itm) << std::endl;
|
||||
else if(dtype == "int8_t")
|
||||
file << static_cast<int>(type_convert<ck_tile::int8_t>(itm)) << std::endl;
|
||||
else
|
||||
// TODO: we didn't implement operator<< for all custom
|
||||
// data types, here fall back to float in case compile error
|
||||
@@ -681,6 +749,24 @@ struct HostTensor
|
||||
Data mData;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Creates a host tensor descriptor with specified dimensions and layout
|
||||
*
|
||||
* Constructs a HostTensorDescriptor with appropriate strides based on whether the tensor
|
||||
* layout is row-major or column-major. This is determined via the compile-time template
|
||||
* parameter `is_row_major`.
|
||||
*
|
||||
* @tparam is_row_major Compile-time flag indicating if the layout is row-major (true) or
|
||||
* column-major (false)
|
||||
*
|
||||
* @param row Number of rows in the tensor
|
||||
* @param col Number of columns in the tensor
|
||||
* @param stride Stride between adjacent rows (for row-major) or columns (for column-major)
|
||||
*
|
||||
* @return HostTensorDescriptor with shape {row, col} and strides:
|
||||
* - For row-major: {stride, 1}
|
||||
* - For column-major: {1, stride}
|
||||
*/
|
||||
template <bool is_row_major>
|
||||
auto host_tensor_descriptor(std::size_t row,
|
||||
std::size_t col,
|
||||
@@ -698,6 +784,7 @@ auto host_tensor_descriptor(std::size_t row,
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool is_row_major>
|
||||
auto get_default_stride(std::size_t row,
|
||||
std::size_t col,
|
||||
@@ -718,5 +805,4 @@ auto get_default_stride(std::size_t row,
|
||||
else
|
||||
return stride;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -15,7 +15,7 @@ struct joinable_thread : std::thread
|
||||
{
|
||||
}
|
||||
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread& operator=(joinable_thread&&) = default;
|
||||
|
||||
~joinable_thread()
|
||||
|
||||
@@ -1,23 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
#endif
|
||||
__global__ void kentry(Args... args)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
Kernel{}(args...);
|
||||
#else
|
||||
(..., (ignore = args, 0));
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
@@ -51,6 +59,60 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
|
||||
}
|
||||
}
|
||||
|
||||
// Measure the preprocess time during the cold iterations
|
||||
template <typename TimerType, typename PreprocessFunc>
|
||||
CK_TILE_HOST double
|
||||
preprocess_profiling_impl(TimerType timer, const stream_config& s, PreprocessFunc preprocess)
|
||||
{
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++)
|
||||
{
|
||||
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
|
||||
{
|
||||
preprocess();
|
||||
}
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
}
|
||||
|
||||
template <typename TimerType, typename CallablesFunc, typename PreprocessFunc = std::nullptr_t>
|
||||
CK_TILE_HOST double timing_loop_impl(TimerType timer,
|
||||
const stream_config& s,
|
||||
CallablesFunc&& callables_func,
|
||||
PreprocessFunc preprocess = nullptr)
|
||||
{
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
callables_func();
|
||||
}
|
||||
// Only profile preprocess if it's provided
|
||||
auto preprocess_time = 0.0;
|
||||
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
|
||||
{
|
||||
preprocess_time = preprocess_profiling_impl(gpu_timer{}, s, preprocess);
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
timer.start(s.stream_id_);
|
||||
while(i < s.nrepeat_)
|
||||
{
|
||||
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
|
||||
{
|
||||
preprocess();
|
||||
}
|
||||
|
||||
callables_func();
|
||||
i++;
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
if(!i)
|
||||
return 0.;
|
||||
return (timer.duration() / s.nrepeat_) - preprocess_time;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
/*
|
||||
* launch_kernel()
|
||||
@@ -81,37 +143,48 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
};
|
||||
auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
|
||||
|
||||
if(s.is_gpu_timer_)
|
||||
{
|
||||
return time_launches(gpu_timer{});
|
||||
return timing_loop_impl(gpu_timer{}, s, callables_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
return time_launches(cpu_timer{});
|
||||
return timing_loop_impl(cpu_timer{}, s, callables_func);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PreprocessFunc, typename... Callables>
|
||||
CK_TILE_HOST float
|
||||
launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
preprocess();
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
|
||||
|
||||
if(s.is_gpu_timer_)
|
||||
{
|
||||
return timing_loop_impl(gpu_timer{}, s, callables_func, preprocess);
|
||||
}
|
||||
else
|
||||
{
|
||||
return timing_loop_impl(cpu_timer{}, s, callables_func, preprocess);
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -11,6 +11,110 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename QDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
uint32_t QuantGroupSize,
|
||||
bool aquant,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<QDataType>& q,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0, v_block_acc = 0;
|
||||
|
||||
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, pk_int4_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> ||
|
||||
std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
v_block_acc += v_a * v_b;
|
||||
|
||||
// Apply group dequant scale
|
||||
if((k + 1) % QuantGroupSize == 0)
|
||||
{
|
||||
float scale = 0.f;
|
||||
index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
|
||||
index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
|
||||
|
||||
if constexpr(std::is_same_v<QDataType, float>)
|
||||
{
|
||||
scale = q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected Q datatype.");
|
||||
}
|
||||
v_block_acc *= scale;
|
||||
v_acc += v_block_acc;
|
||||
v_block_acc = 0;
|
||||
}
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -71,6 +175,58 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ACCElementOp,
|
||||
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
|
||||
CK_TILE_HOST void
|
||||
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_k_n(k, n);
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType v_c = 0;
|
||||
if constexpr(DsDataType::size() == 0)
|
||||
{
|
||||
acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 1)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 2)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)),
|
||||
ck_tile::type_convert<float>(ds_m_n[1](m, n)));
|
||||
}
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST void
|
||||
reference_grouped_conv_bwd_weight(const HostTensor<InDataType>& input,
|
||||
HostTensor<WeiDataType>& weight,
|
||||
const HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
output.get_num_of_dimension() == NDimSpatial + 3))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, wi);
|
||||
OutDataType v_out = output(g, n, k, wo);
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
if(hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, hi, wi);
|
||||
OutDataType v_out = output(g, n, k, ho, wo);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, y, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3],
|
||||
weight.get_lengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
|
||||
{
|
||||
auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
|
||||
if(di >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
|
||||
hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, di, hi, wi);
|
||||
OutDataType v_out = output(g, n, k, do_, ho, wo);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, z, y, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3],
|
||||
weight.get_lengths()[4],
|
||||
weight.get_lengths()[5])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
165
include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp
Normal file
165
include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp
Normal file
@@ -0,0 +1,165 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Reference (CPU) grouped convolution forward pass:
///        out[g,n,k,spatial] = sum over (c, filter taps) of input * weight.
///
/// Tensor layouts (NDimSpatial in {1,2,3}):
///   input  [G, N, C, (Di,) (Hi,) Wi]
///   weight [G, K, C, (Z,)  (Y,)  X]
///   output [G, N, K, (Do,) (Ho,) Wo]   — written by this function
/// The last (unnamed) parameter is the right-pad vector; it does not enter the
/// index arithmetic and is intentionally ignored.
/// @throws std::runtime_error on rank mismatch or unsupported NDimSpatial.
template <ck_tile::index_t NDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
                                             const HostTensor<WeiDataType>& weight,
                                             HostTensor<OutDataType>& output,
                                             std::vector<ck_tile::long_index_t> conv_strides,
                                             std::vector<ck_tile::long_index_t> conv_dilations,
                                             std::vector<ck_tile::long_index_t> in_left_pads,
                                             std::vector<ck_tile::long_index_t>)
{
    // All three tensors must have rank = spatial dims + (group, batch, channel).
    if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
         weight.get_num_of_dimension() == NDimSpatial + 3 &&
         output.get_num_of_dimension() == NDimSpatial + 3))
    {
        throw std::runtime_error("wrong! inconsistent dimension");
    }

    if constexpr(NDimSpatial == 1)
    {
        // One task per output element (g, n, k, wo); accumulate in float.
        auto func = [&](auto g, auto n, auto k, auto wo) {
            float v_acc = 0;

            for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
            {
                for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
                {
                    // Map (output position, filter tap) to the input column.
                    auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);

                    // Taps that fall into the padding contribute nothing.
                    if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
                    {
                        InDataType v_in   = input(g, n, c, wi);
                        WeiDataType v_wei = weight(g, k, c, x);
                        v_acc += ck_tile::type_convert<float>(v_in) *
                                 ck_tile::type_convert<float>(v_wei);
                    }
                }
            }
            OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
            output(g, n, k, wo)         = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   output.get_lengths()[0],
                                   output.get_lengths()[1],
                                   output.get_lengths()[2],
                                   output.get_lengths()[3])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 2)
    {
        auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
            float v_acc = 0;

            for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
            {
                for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
                {
                    // Row index is invariant over the inner x loop, so compute it here.
                    auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);

                    for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
                    {
                        auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
                                  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
                                  static_cast<ck_tile::long_index_t>(in_left_pads[1]);

                        if(hi >= 0 &&
                           ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
                           wi >= 0 &&
                           ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
                        {
                            InDataType v_in   = input(g, n, c, hi, wi);
                            WeiDataType v_wei = weight(g, k, c, y, x);

                            v_acc += ck_tile::type_convert<float>(v_in) *
                                     ck_tile::type_convert<float>(v_wei);
                        }
                    }
                }
            }
            OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
            output(g, n, k, ho, wo)     = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   output.get_lengths()[0],
                                   output.get_lengths()[1],
                                   output.get_lengths()[2],
                                   output.get_lengths()[3],
                                   output.get_lengths()[4])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 3)
    {
        auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
            float v_acc = 0;

            for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
            {
                for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
                {
                    auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);
                    for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
                    {
                        auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
                                  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
                                  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
                        for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
                        {
                            auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
                                      static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
                                      static_cast<ck_tile::long_index_t>(in_left_pads[2]);
                            // All three spatial indices must land inside the input.
                            if(di >= 0 &&
                               ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
                               hi >= 0 &&
                               ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
                               wi >= 0 &&
                               ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
                            {
                                InDataType v_in   = input(g, n, c, di, hi, wi);
                                WeiDataType v_wei = weight(g, k, c, z, y, x);

                                v_acc += ck_tile::type_convert<float>(v_in) *
                                         ck_tile::type_convert<float>(v_wei);
                            }
                        }
                    }
                }
            }
            OutDataType v_acc_converted  = ck_tile::type_convert<OutDataType>(v_acc);
            output(g, n, k, d_o, ho, wo) = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   output.get_lengths()[0],
                                   output.get_lengths()[1],
                                   output.get_lengths()[2],
                                   output.get_lengths()[3],
                                   output.get_lengths()[4],
                                   output.get_lengths()[5])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
    }
}
|
||||
} // namespace ck_tile
|
||||
@@ -9,7 +9,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
|
||||
static_cast<uint32_t>(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24))
|
||||
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
@@ -21,10 +21,12 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size,
|
||||
const index_t tokens,
|
||||
bool local_expert_masking,
|
||||
bool skip_experts_with_zero_token = true)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
// note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], indicating local_token case
|
||||
const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
// allocate a temp buffer, and fill the value with [number_token|topk]
|
||||
std::vector<std::vector<IndexType>> expert_tokens(
|
||||
|
||||
@@ -30,4 +30,82 @@ reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m,
|
||||
|
||||
make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
// Generic reference reduce for arbitrary dimensions.
//
// Reduces x_tensor over the dimensions listed in ReduceDims and writes the
// result into y_tensor, whose axes correspond (in order) to the KeptDim
// entries. Each kept-index combination is an independent parallel task that
// folds every reduce-index combination into reduce_op, starting from the op's
// identity value.
//
// NOTE(review): assumes y_tensor's lengths match the kept dimensions of
// x_tensor in KeptDim order — not validated here; confirm at call sites.
template <
    typename XDataType,
    typename ComputeDataType,
    typename YDataType,
    typename ReduceOp,
    typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to keep
    typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to
                         // reduce
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
                                   HostTensor<YDataType>& y_tensor,
                                   ReduceOp reduce_op,
                                   KeptDim kept_dim,
                                   ReduceDims reduce_dims)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    // One parallel task per kept-index combination.
    auto f = [&](auto linear_kept_idx) {
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        // Convert linear kept index to multi-dimensional kept indices
        // (row-major: the last kept dimension varies fastest).
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim     = kept_dim.at(dim_idx);
            const auto len         = x_lengths[dim];
            kept_indices[dim_idx]  = temp_kept % len;
            temp_kept /= len;
        });

        for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
        {
            // Convert linear reduce index to multi-dimensional reduce indices
            // (same row-major decomposition as above).
            std::vector<index_t> reduce_indices(reduce_dims.size());
            index_t temp_reduce = reduce_idx;
            static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                constexpr auto dim_idx  = reduce_dims.size() - 1 - i;
                constexpr auto dim      = reduce_dims.at(dim_idx);
                const auto len          = x_lengths[dim];
                reduce_indices[dim_idx] = temp_reduce % len;
                temp_reduce /= len;
            });

            // Build full input tensor indices by combining kept and reduce indices
            std::vector<std::size_t> full_indices(x_lengths.size(), 0);
            static_for<0, kept_dim.size(), 1>{}(
                [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
            static_for<0, reduce_dims.size(), 1>{}(
                [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

            // Access input tensor element
            const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));

            v_acc = reduce_op(v_acc, v_a);
        }

        // Calculate output tensor index using kept indices
        // The output tensor has the same structure as the kept dimensions
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        y_tensor(y_indices) = type_convert<YDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,7 +14,7 @@ CK_TILE_HOST void
|
||||
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
|
||||
{
|
||||
index_t rank = x.get_num_of_dimension();
|
||||
assert(rank == y.get_num_of_dimension());
|
||||
assert(static_cast<std::size_t>(rank) == y.get_num_of_dimension());
|
||||
assert(dim == -1 || dim < rank);
|
||||
|
||||
index_t target_dim = dim == -1 ? (rank - 1) : dim;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -38,8 +38,8 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
|
||||
{
|
||||
// rank must be the same
|
||||
index_t rank = x.get_num_of_dimension();
|
||||
assert(rank == y_values.get_num_of_dimension());
|
||||
assert(rank == y_indices.get_num_of_dimension());
|
||||
assert(static_cast<std::size_t>(rank) == y_values.get_num_of_dimension());
|
||||
assert(static_cast<size_t>(rank) == y_indices.get_num_of_dimension());
|
||||
assert(dim == -1 || dim < rank);
|
||||
|
||||
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
|
||||
@@ -47,7 +47,8 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
|
||||
auto x_len = x.get_lengths();
|
||||
|
||||
assert(k <= topk_src_len);
|
||||
assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
|
||||
assert(static_cast<size_t>(k) == y_values.get_length(topk_dim) &&
|
||||
static_cast<size_t>(k) == y_indices.get_length(topk_dim));
|
||||
|
||||
index_t n_parallel = x.get_element_size() / topk_src_len;
|
||||
|
||||
|
||||
33
include/ck_tile/host/reference/reference_transpose.hpp
Normal file
33
include/ck_tile/host/reference/reference_transpose.hpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename BDataType>
|
||||
void reference_transpose_elementwise(const HostTensor<ADataType>& a, HostTensor<BDataType>& b)
|
||||
{
|
||||
ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
|
||||
ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
|
||||
|
||||
// Ensure the b tensor is sized correctly for N x M
|
||||
if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
|
||||
static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
|
||||
{
|
||||
throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
|
||||
}
|
||||
|
||||
auto f = [&](auto i, auto j) {
|
||||
auto v_a = a(i, j);
|
||||
b(j, i) = ck_tile::type_convert<BDataType>(v_a);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
RotatingMemWrapper() = delete;
|
||||
RotatingMemWrapper(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
rotating_count(rotating_count_),
|
||||
size_a(size_a_),
|
||||
size_b(size_b_)
|
||||
{
|
||||
p_a_grids.push_back(a_ptr);
|
||||
p_b_grids.push_back(b_ptr);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
}
|
||||
}
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
a_ptr = p_a_grids[idx];
|
||||
b_ptr = p_b_grids[idx];
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapper() noexcept
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
a_ptr = p_a_grids[0];
|
||||
b_ptr = p_b_grids[0];
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::size_t size_a = 0;
|
||||
std::size_t size_b = 0;
|
||||
std::vector<const void*> p_a_grids;
|
||||
std::vector<const void*> p_b_grids;
|
||||
};
|
||||
inline void flush_icache()
|
||||
{
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
|
||||
|
||||
ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -20,6 +20,10 @@ namespace ck_tile {
|
||||
*
|
||||
* // create stream config with _some_stream_id_, and benchmark using cpu timer
|
||||
* stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
|
||||
*
|
||||
* // create stream config with _some_stream_id_, and enable gpu timer for rotating buffer with
|
||||
*rotating buffer count stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, true,
|
||||
*true, 1};
|
||||
**/
|
||||
|
||||
struct stream_config
|
||||
@@ -30,5 +34,7 @@ struct stream_config
|
||||
int cold_niters_ = 3;
|
||||
int nrepeat_ = 10;
|
||||
bool is_gpu_timer_ = true; // keep compatible
|
||||
bool flush_cache_ = false;
|
||||
int rotating_count_ = 1;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
45
include/ck_tile/host/stream_utils.hpp
Normal file
45
include/ck_tile/host/stream_utils.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Returns the number of compute units available to the stream in
///        @p s, by querying the stream's CU mask and counting its set bits.
static inline index_t get_available_compute_units(const stream_config& s)
{
    constexpr static uint32_t MAX_MASK_DWORDS = 64;

    // assume at most 64*32 = 2048 CUs
    uint32_t cu_mask[MAX_MASK_DWORDS]{};

    HIP_CHECK_ERROR(hipExtStreamGetCUMask(s.stream_id_, MAX_MASK_DWORDS, &cu_mask[0]));

    index_t num_cu = 0;
    for(uint32_t i = 0; i < MAX_MASK_DWORDS; i++)
    {
        // Kernighan's trick: each iteration clears the lowest set bit, so the
        // loop body runs once per set bit instead of once per bit position.
        uint32_t dword = cu_mask[i];
        while(dword != 0)
        {
            dword &= dword - 1;
            ++num_cu;
        }
    }

    return num_cu;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,6 +4,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
|
||||
|
||||
@@ -19,7 +19,6 @@ struct BatchedTransposeHostArgs
|
||||
index_t batch;
|
||||
index_t height;
|
||||
index_t width;
|
||||
// index_t dim_blocks;
|
||||
index_t dim_stride;
|
||||
index_t dim_block_h;
|
||||
index_t dim_block_w;
|
||||
@@ -28,10 +27,12 @@ struct BatchedTransposeHostArgs
|
||||
template <typename Pipeline_>
|
||||
struct BatchedTransposeKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using Type = typename Problem::InputType;
|
||||
CK_TILE_DEVICE static index_t counter = 0;
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using Type = typename Problem::DataType;
|
||||
|
||||
struct BatchedTransposeKargs
|
||||
{
|
||||
@@ -46,11 +47,13 @@ struct BatchedTransposeKernel
|
||||
using Kargs = BatchedTransposeKargs;
|
||||
using Hargs = BatchedTransposeHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args)
|
||||
{
|
||||
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
|
||||
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
|
||||
size_t grid_size_z = h.batch;
|
||||
const size_t grid_size_x =
|
||||
ck_tile::integer_divide_ceil(host_args.height, host_args.dim_block_h);
|
||||
const size_t grid_size_y =
|
||||
ck_tile::integer_divide_ceil(host_args.width, host_args.dim_block_w);
|
||||
const size_t grid_size_z = host_args.batch;
|
||||
return dim3(grid_size_x, grid_size_y, grid_size_z);
|
||||
}
|
||||
|
||||
@@ -66,62 +69,58 @@ struct BatchedTransposeKernel
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput;
|
||||
static constexpr ck_tile::index_t VectorStrideInput = 1;
|
||||
static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput;
|
||||
static constexpr ck_tile::index_t VectorStrideOutput = 1;
|
||||
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
|
||||
const auto offset = __builtin_amdgcn_readfirstlane(blockIdx.z * kargs.height * kargs.width);
|
||||
|
||||
static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
|
||||
static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
|
||||
|
||||
static_assert(kMPerThread == 1 && kNPerThread == 1);
|
||||
|
||||
const auto iDim = blockIdx.z;
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
|
||||
static_cast<const Type*>(kargs.p_input) + offset,
|
||||
make_tuple(kargs.height, kargs.width),
|
||||
make_tuple(kargs.width, 1),
|
||||
number<kNPerThread>{}, // TODO thread load value
|
||||
number<1>{});
|
||||
number<VectorSizeInput>{},
|
||||
number<VectorStrideInput>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
}();
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
|
||||
|
||||
const auto y_n_m = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
|
||||
static_cast<Type*>(kargs.p_output) + offset,
|
||||
make_tuple(kargs.width, kargs.height),
|
||||
make_tuple(kargs.height, 1),
|
||||
number<kMPerThread>{},
|
||||
number<1>{});
|
||||
number<VectorSizeOutput>{},
|
||||
number<VectorStrideOutput>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
sequence<kPadN, kPadM>{});
|
||||
}();
|
||||
|
||||
auto x_block_window =
|
||||
make_tile_window(x_m_n,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iM * kMPerBlock),
|
||||
static_cast<ck_tile::index_t>(iN * kNPerBlock)});
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
|
||||
|
||||
auto y_block_window =
|
||||
make_tile_window(y_n_m,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iN * kNPerBlock),
|
||||
static_cast<ck_tile::index_t>(iM * kMPerBlock)});
|
||||
auto y_block_window = make_tile_window(
|
||||
y_n_m,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
|
||||
|
||||
Pipeline{}(x_block_window, y_block_window);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposeCommonPolicy
|
||||
{
|
||||
CK_TILE_DEVICE static constexpr auto TileAccessPattern =
|
||||
tile_distribution_pattern::thread_raked;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kMPerBlock;
|
||||
|
||||
constexpr index_t kVectorSize = Problem::VectorSizeInput;
|
||||
static_assert((kLeadDimPerBlock * kVectorSize) % kBlockSize == 0, "");
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kSecondDimPerBlock,
|
||||
kLeadDimPerBlock,
|
||||
kVectorSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct BatchedTransposeLdsPipeline
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = remove_cvref_t<typename Problem::DataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
|
||||
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
|
||||
|
||||
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename InputTileWindow, typename OutputTileWindow>
|
||||
CK_TILE_DEVICE void operator()(const InputTileWindow& input_window,
|
||||
OutputTileWindow& output_window)
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
auto input_tile_window =
|
||||
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
|
||||
auto output_tile_window =
|
||||
make_tile_window(output_window, Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
DataType* p_lds_ptr = reinterpret_cast<DataType*>(smem);
|
||||
constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor<Problem>();
|
||||
auto input_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, in_lds_block_desc);
|
||||
|
||||
constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor<Problem>();
|
||||
auto output_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, out_lds_block_desc);
|
||||
|
||||
auto copy_to_lds_window =
|
||||
make_tile_window(input_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0});
|
||||
auto load_from_lds_window =
|
||||
make_tile_window(output_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeLdsLoadTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(input_tile_window);
|
||||
|
||||
store_tile(copy_to_lds_window, x);
|
||||
block_sync_lds();
|
||||
|
||||
auto y = load_tile_transpose(load_from_lds_window);
|
||||
|
||||
store_tile(output_tile_window, y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "batched_transpose_common_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposeLdsPolicy : public BatchedTransposeCommonPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return integer_least_multiple(
|
||||
sizeof(typename Problem::DataType) *
|
||||
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
|
||||
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<typename decltype(input_dstr)::DstrEncode,
|
||||
typename Problem::DataType>::TransposedDstrEncode;
|
||||
constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
return block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t kVectorSize = Problem::LDSVectorSize;
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t kVectorSize = Problem::LDSVectorSize;
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
|
||||
{
|
||||
using DataType = typename Problem::DataType;
|
||||
|
||||
// Calculate block-level dimensions
|
||||
constexpr index_t kLeadIterPerWarp = 1;
|
||||
constexpr index_t kSecondIterPerWarp = 1;
|
||||
constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
|
||||
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
|
||||
|
||||
// Calculate repetitions of base pattern
|
||||
constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim;
|
||||
constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim;
|
||||
constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
|
||||
constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
|
||||
|
||||
constexpr index_t kLaneGroupSize = 16;
|
||||
constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
|
||||
kLaneGroupSize,
|
||||
kSecondDimStrSub,
|
||||
kSecondDimIterations,
|
||||
kLeadRepetitions,
|
||||
1>();
|
||||
|
||||
constexpr auto input_tile_encode =
|
||||
InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
|
||||
kLeadIterPerWarp,
|
||||
kSecondIterPerWarp,
|
||||
kLeadNumWarps,
|
||||
kSecondNumWarps>();
|
||||
constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
|
||||
return block_dstr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// supports 2D transpose which will store to lds,
|
||||
// then use ds_read_b*_tr_b* instruction to get the transposed data
|
||||
template <typename DataType_,
|
||||
typename BlockTile, // sequence<block_x, block_y>
|
||||
typename NumWarps,
|
||||
bool kPadM_,
|
||||
bool kPadN_>
|
||||
struct BatchedTransposeLdsProblem
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{});
|
||||
static constexpr index_t kColWarps_ = NumWarps::at(number<1>{});
|
||||
static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_;
|
||||
static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
// warps per block
|
||||
static constexpr index_t kLeadNumWarps = kColWarps_;
|
||||
static constexpr index_t kSecondNumWarps = kRowWarps_;
|
||||
|
||||
static constexpr index_t kLeadSizePerBlock = kColPerBlock_;
|
||||
static constexpr index_t kSecondSizePerBlock = kRowPerBlock_;
|
||||
|
||||
static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
|
||||
static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
|
||||
"block dim should be divided by warp count!");
|
||||
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
|
||||
"block dim should be divided by warp count!");
|
||||
// rows/cols per warp
|
||||
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
|
||||
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
|
||||
|
||||
static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
// xdl rows/cols is divided into quadrants.
|
||||
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerWarp / kQuadrantLeadDim;
|
||||
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerWarp / kQuadrantSecondDim;
|
||||
|
||||
static constexpr index_t kIterationsInSecondDim =
|
||||
kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
|
||||
|
||||
// definitions to adapt to BatchedTransposeKernel
|
||||
|
||||
// FIXME: support padding
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
|
||||
static constexpr auto kMPerBlock = kSecondSizePerBlock;
|
||||
static constexpr auto kNPerBlock = kLeadSizePerBlock;
|
||||
|
||||
// 128-bit is the max single-instruction bandwidth for load/store
|
||||
static constexpr index_t MaxLoadStoreSize = 16;
|
||||
static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
static constexpr auto VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -5,8 +5,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,39 +12,26 @@ template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
|
||||
struct BatchedTransposePipeline
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using InputType = ck_tile::remove_cvref_t<typename Problem::InputType>;
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t AlignmentM = Problem::AlignmentM;
|
||||
static constexpr index_t AlignmentN = Problem::AlignmentN;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
template <typename InputWindow, typename OutputWindow>
|
||||
CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
|
||||
{
|
||||
auto inp_win =
|
||||
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
|
||||
|
||||
auto input_tile = load_tile(inp_win);
|
||||
|
||||
auto output_tile = make_static_distributed_tensor<typename Problem::DataType>(
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
transpose_tile2d(output_tile, input_tile);
|
||||
|
||||
auto out_win =
|
||||
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(inp_win); // x->thread input_win->block
|
||||
|
||||
auto y = make_static_distributed_tensor<InputType>(
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx1, idx0);
|
||||
y(i_j_idx) = x(i_j_idx);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(out_win, y);
|
||||
store_tile(out_win, output_tile);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,41 +4,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/softmax.hpp"
|
||||
#include "ck_tile/ops/topk.hpp"
|
||||
#include "batched_transpose_common_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposePolicy
|
||||
struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
using S = Problem;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
|
||||
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>{});
|
||||
}
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t NPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t VecLoadSize = Problem::VectorSizeOutput;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
using S = Problem;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
|
||||
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 2>>{});
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,45 +4,33 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#define VectorLoadSize 16
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InputType_,
|
||||
typename BlockTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename ThreadTile, // Sequence<...
|
||||
bool kPadM_ = true,
|
||||
bool kPadN_ = true>
|
||||
template <typename DataType_,
|
||||
typename BlockTile, // Sequence<...
|
||||
typename WarpLayout,
|
||||
bool kPadM_ = false,
|
||||
bool kPadN_ = false> // Sequence<...
|
||||
struct BatchedTransposeProblem
|
||||
{
|
||||
using InputType = remove_cvref_t<InputType_>;
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
|
||||
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
|
||||
static constexpr index_t kMPerWarp = WarpLayout::at(number<0>{});
|
||||
static constexpr index_t kNPerWarp = WarpLayout::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock;
|
||||
static constexpr index_t kBlockSize = kMPerWarp * kNPerWarp * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
|
||||
static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
|
||||
static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
|
||||
// 128-bit is the max single-instruction bandwidth for load/store
|
||||
static constexpr index_t MaxLoadStoreSize = 16;
|
||||
static constexpr index_t VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
static constexpr index_t VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
0
include/ck_tile/ops/common/utils.hpp
Executable file → Normal file
0
include/ck_tile/ops/common/utils.hpp
Executable file → Normal file
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
|
||||
#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
struct Add
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + type_convert<half_t>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0 + x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0) + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x1_tmp = type_convert<float>(x1);
|
||||
y = x0 + x1_tmp;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bf16_t>(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x1_tmp = type_convert<float>(x0);
|
||||
const float x2_tmp = type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp + x2_tmp;
|
||||
y = type_convert<bf16_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bf16_t>(bf16_t& y, const float& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x2_tmp = type_convert<float>(x1);
|
||||
const float y_tmp = x0 + x2_tmp;
|
||||
y = type_convert<bf16_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace ck_tile
|
||||
123
include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp
Normal file
123
include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp
Normal file
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct ElementWiseKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
|
||||
|
||||
template <typename... XDataType, typename Dims>
|
||||
CK_TILE_DEVICE void operator()(Dims lens,
|
||||
Dims input_strides,
|
||||
Dims output_strides,
|
||||
const tuple<XDataType...>& input_tensors,
|
||||
YDataType* p_y) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Setup block-level coordinates and transforms
|
||||
const index_t iM = get_block_id() * S::kBlockM;
|
||||
const auto merge_transform = make_merge_transform(lens);
|
||||
|
||||
// Load all input tiles into registers.
|
||||
// The lambda structure here is intended to minimize the lifetime
|
||||
// of intermediate objects (views, windows) used for loading.
|
||||
const auto x_tiles = ck_tile::generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
|
||||
|
||||
const auto transformed_tensor = pad_tensor_view(
|
||||
transform_tensor_view(tensor_view,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
const auto x_window =
|
||||
make_tile_window(transformed_tensor,
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
{iM},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
return load_tile(x_window);
|
||||
},
|
||||
number<sizeof...(XDataType)>{});
|
||||
|
||||
// Setup output tile in registers.
|
||||
const auto& x_tile0 = x_tiles.get(number<0>{});
|
||||
auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
|
||||
|
||||
// Perform element-wise computation.
|
||||
const auto spans = x_tile0.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx) {
|
||||
const auto tile_idx = make_tuple(idx);
|
||||
apply(
|
||||
[&](auto&&... tiles) {
|
||||
ElementWiseOperation{}(y_tile(tile_idx),
|
||||
type_convert<ComputeDataType>(tiles[tile_idx])...);
|
||||
},
|
||||
x_tiles);
|
||||
});
|
||||
|
||||
// Setup output window and store the result tile.
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, lens, output_strides, number<S::kVectorM>{});
|
||||
|
||||
const auto transformed_y_m_n = pad_tensor_view(
|
||||
transform_tensor_view(y_m_n,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
auto y_window = make_tile_window(transformed_y_m_n,
|
||||
make_tuple(number<S::kBlockM>{}),
|
||||
{iM},
|
||||
y_tile.get_tile_distribution());
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_tile));
|
||||
}
|
||||
|
||||
template <typename... Ints>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes)
|
||||
{
|
||||
int total_elements = 1;
|
||||
const auto kVectorM = Problem_::BlockShape::kVectorM;
|
||||
|
||||
apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
|
||||
|
||||
if((total_elements % kVectorM) != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met: total number of input elements (",
|
||||
total_elements,
|
||||
") should be multiple of the vectorization size (",
|
||||
kVectorM,
|
||||
")");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for the element-wise kernel: defines how the (flattened,
// one-dimensional) X tile is distributed over the block's warps, lanes and
// per-thread vectors.
struct ElementWiseDefaultPolicy
{
    // Builds the static tile distribution used to load the X tile.
    // The single hierarchical dimension is split, outer to inner, into
    // kRepeatM x kWarpPerBlockM x kThreadPerWarpM x kVectorM
    // (taken from Problem::BlockShape).
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<>, // no replication (empty R lengths)
                                       tuple<sequence<S::kRepeatM,
                                                      S::kWarpPerBlockM,
                                                      S::kThreadPerWarpM,
                                                      S::kVectorM>>, // hierarchical split of dim 0
                                       tuple<sequence<1>, sequence<1>>, // P major: both map into H dim 1
                                       tuple<sequence<1>, sequence<2>>, // P minor: warp (idx 1), lane (idx 2)
                                       sequence<1, 1>, // Y major: both from H dim 1
                                       sequence<0, 3>>{} // Y minor: repeat (idx 0) and vector (idx 3)
        );
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Problem descriptor for the element-wise pipeline: bundles the input/compute/
// output element types, the block tile shape, the functor applied per element,
// and whether the tile windows are padded at tensor boundaries.
template <typename XDataType_,
          typename ComputeDataType_,
          typename YDataType_,
          typename BlockShape_,
          typename ElementWiseOperation_,
          bool kPad_ = true>
struct ElementWisePipelineProblem
{
    // Element type of the input tensor(s).
    using XDataType = remove_cvref_t<XDataType_>;
    // Element type the element-wise functor operates on (inputs are converted to it).
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    // Element type of the output tensor.
    using YDataType = remove_cvref_t<YDataType_>;
    // Block/warp/vector tiling parameters (e.g. ElementWiseShape).
    using BlockShape = remove_cvref_t<BlockShape_>;
    // Functor invoked per element: op(y, converted_xs...).
    using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
    // Whether tile views are padded so partially-covered tiles are valid.
    static constexpr bool kPad = kPad_;
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compile-time tile-shape traits for the element-wise kernel.
// BlockWarps: warps per block along each dimension (index 0 = M).
// BlockTile : elements per block (index 0 = M).
// WarpTile  : elements per warp (index 0 = M).
// ComputeDataType: element type used to size the per-thread vector.
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename ComputeDataType>
struct ElementWiseShape
{
    // Elements handled by one block along M.
    static constexpr index_t kBlockM = BlockTile::at(number<0>{});

    // Elements handled by one warp along M.
    static constexpr index_t kWarpM = WarpTile::at(number<0>{});

    // Per-thread vector width: capped at 16 bytes worth of ComputeDataType
    // and at each lane's share of the warp tile.
    static constexpr index_t kVectorM =
        min(static_cast<index_t>(16 / sizeof(ComputeDataType)), kWarpM / get_warp_size());

    // Warps per block along M.
    static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{});

    // Lanes per warp along M (the full hardware warp).
    static constexpr index_t kThreadPerWarpM = get_warp_size();

    // How many vectorized accesses each thread repeats to cover kBlockM.
    // NOTE(review): assumes kBlockM divides evenly by the denominator — confirm at call sites.
    static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);

    // Total threads per block: warp size times the product of all BlockWarps entries.
    static constexpr index_t kBlockSize =
        ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -110,6 +110,86 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
return res;
|
||||
}
|
||||
|
||||
// Convert 8 packed int4 values (two per byte of `a`) into 8 fp8 values using
// AMD VALU inline assembly (device-only; requires hardware with
// v_cvt_off_f32_i4 / v_cvt_pk_fp8_f32 — see the AMD ISA reference).
//
// Scheme: each byte of `a` holds two nibbles. `src_hi = src >> 4` aligns the
// high nibbles so that v_cvt_off_f32_i4 with src0_sel:BYTE_k reads either the
// low nibble (from `src`) or the high nibble (from `src_hi`) of byte k.
// v_cvt_pk_fp8_f32 then packs two floats into two fp8 bytes; op_sel:[0, 0, 1]
// presumably targets the upper 16 bits of the destination dword — confirm
// against the ISA docs. Processing order BYTE_3..BYTE_0 fills dst_hi then
// dst_lo, giving 8 fp8 bytes total.
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
{
    uint32_t src = static_cast<uint32_t>(a), src_hi; // src_hi: high nibbles shifted down
    uint32_t fp8x4_lo, fp8x4_hi;                     // 4 packed fp8 results per dword
    float tmp_0, tmp_1;                              // f32 scratch for the conversions

    asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"

                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"

                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"

                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
                 : [v_tmp_0] "+v"(tmp_0),
                   [v_tmp_1] "+v"(tmp_1),
                   [v_hi_src] "+v"(src_hi),
                   [v_dst_lo] "+v"(fp8x4_lo),
                   [v_dst_hi] "+v"(fp8x4_hi),
                   [v_src] "+v"(src)
                 :);

    // Assemble the 8 fp8 bytes into one 64-bit value and reinterpret.
    return bit_cast<fp8x8_t>(((static_cast<uint64_t>(fp8x4_hi) << 32) | fp8x4_lo));
}
|
||||
|
||||
// Convert the lowest byte of `src` from fp8 to float via the hardware
// v_cvt_f32_fp8 instruction (src0_sel:BYTE_0 selects byte 0).
// Device-only; requires a GPU target that implements v_cvt_f32_fp8.
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
    float res;
    asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
    return res;
}
|
||||
|
||||
// Convert the lowest byte of `src` from bf8 to float via the hardware
// v_cvt_f32_bf8 instruction (src0_sel:BYTE_0 selects byte 0).
// Device-only; requires a GPU target that implements v_cvt_f32_bf8.
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
{
    float res;
    asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
    return res;
}
|
||||
|
||||
// Convert 8 packed int4 values (two per byte of `a`) into 8 bf8 values.
// Identical structure to amd_assembly_i4_to_fp8x8, but packs with
// v_cvt_pk_bf8_f32. Device-only; requires hardware with
// v_cvt_off_f32_i4 / v_cvt_pk_bf8_f32 — see the AMD ISA reference.
//
// `src_hi = src >> 4` aligns the high nibbles so src0_sel:BYTE_k reads either
// the low nibble (from `src`) or the high nibble (from `src_hi`) of byte k;
// op_sel:[0, 0, 1] presumably writes the upper 16 bits of the destination
// dword — confirm against the ISA docs.
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a)
{
    uint32_t src = static_cast<uint32_t>(a), src_hi; // src_hi: high nibbles shifted down
    uint32_t bf8x4_lo, bf8x4_hi;                     // 4 packed bf8 results per dword
    float tmp_0, tmp_1;                              // f32 scratch for the conversions

    asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
                 "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"

                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
                 "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"

                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
                 "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"

                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
                 "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
                 : [v_tmp_0] "+v"(tmp_0),
                   [v_tmp_1] "+v"(tmp_1),
                   [v_hi_src] "+v"(src_hi),
                   [v_dst_lo] "+v"(bf8x4_lo),
                   [v_dst_hi] "+v"(bf8x4_hi),
                   [v_src] "+v"(src)
                 :);

    // Assemble the 8 bf8 bytes into one 64-bit value and reinterpret.
    return bit_cast<bf8x8_t>(((static_cast<uint64_t>(bf8x4_hi) << 32) | bf8x4_lo));
}
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
@@ -126,6 +206,16 @@ struct PassThroughPack8
|
||||
y.lo = i4_to_bhalf4(bit_cast<int>(x));
|
||||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_fp8x8(bit_cast<int>(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_bf8x8(bit_cast<int>(x));
|
||||
}
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
@@ -172,219 +262,67 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
template <class T>
|
||||
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
|
||||
template <class Y, class X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
|
||||
{
|
||||
y = x;
|
||||
/* Only do the assignment when
|
||||
- y is an *l-value* and
|
||||
- y is *not* const */
|
||||
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
|
||||
{
|
||||
y = ck_tile::type_convert<raw_t<Y>>(x);
|
||||
}
|
||||
/* otherwise (r-value or const) → do nothing */
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
// Suppress unused parameter warning for ds
|
||||
((void)ds, ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
|
||||
{
|
||||
y = type_convert<double>(x);
|
||||
// Just assign e with c
|
||||
if constexpr(std::is_same_v<E, C>)
|
||||
{
|
||||
e = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
e = ck_tile::type_convert<E>(c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
|
||||
struct MultiDMultiply
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Multiply by each D parameter using fold expression
|
||||
((result *= ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
|
||||
struct MultiDAdd
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Add by each D parameter using fold expression
|
||||
((result += ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
|
||||
const ck_tile::bf16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
|
||||
const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<int32_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
|
||||
{
|
||||
y = type_convert<int4_t>(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
|
||||
const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
|
||||
const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1479,5 +1417,6 @@ struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
|
||||
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,34 +11,52 @@ namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename CLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t kMWave_,
|
||||
index_t kNWave_,
|
||||
index_t kMPerXdl_,
|
||||
index_t kNPerXdl_,
|
||||
index_t kKPerXdl_,
|
||||
bool isCTransposed_>
|
||||
index_t MWave_,
|
||||
index_t NWave_,
|
||||
index_t MPerXdl_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t kMWave = kMWave_;
|
||||
static constexpr index_t kNWave = kNWave_;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
static constexpr index_t kKPerXdl = kKPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -49,32 +67,33 @@ struct CShuffleEpilogue
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ODataType, BDataType>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t kMWave = Problem::kMWave;
|
||||
static constexpr index_t kNWave = Problem::kNWave;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
|
||||
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t MWave = Problem::MWave;
|
||||
static constexpr index_t NWave = Problem::NWave;
|
||||
static constexpr index_t MPerXdl = Problem::MPerXdl;
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
@@ -85,100 +104,244 @@ struct CShuffleEpilogue
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
constexpr index_t MaxVectorStoreSize = 16;
|
||||
return MaxVectorStoreSize / sizeof(ODataType);
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeC;
|
||||
}
|
||||
constexpr index_t max_vector_size = 16;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for Di tensor.
|
||||
*
|
||||
* @return The vector store size for Di tensor.
|
||||
*/
|
||||
// Vector access size (in elements) for the I-th D tensor, chosen per layout:
// capped at the contiguous-iteration extent (N for row-major, M for
// column-major) and at 16 bytes worth of the D element type.
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
    // Maximum vector access width in bytes.
    constexpr index_t max_vector_size = 16;
    using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
    using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
    if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
    {
        // N is the contiguous dimension.
        return std::min(static_cast<int>(NPerIteration),
                        static_cast<int>(max_vector_size / sizeof(DiDataType)));
    }
    else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
    {
        // M is the contiguous dimension.
        return std::min(static_cast<int>(MPerIteration),
                        static_cast<int>(max_vector_size / sizeof(DiDataType)));
    }
    else
    {
        static_assert(false, "Unsupported DLayout!");
    }
    // NOTE(review): unreachable when either branch above is taken; presumably
    // kept to silence missing-return diagnostics — consider removing.
    return max_vector_size / sizeof(DiDataType);
}
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
*
|
||||
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
|
||||
* iteration:
|
||||
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
|
||||
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
}
|
||||
}
|
||||
}();
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
|
||||
|
||||
static constexpr auto MNPerIterationShuffle = [] {
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
|
||||
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
|
||||
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
|
||||
else
|
||||
return std::make_tuple(m_val, n_val);
|
||||
}();
|
||||
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
|
||||
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
|
||||
|
||||
using WG = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
|
||||
|
||||
return block_dstr_encoding;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
const index_t iMWarp = get_warp_id() / kNWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
auto in_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
|
||||
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
|
||||
auto out_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
{0, 0});
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kMPerIteration,
|
||||
kNPerIteration,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
CWarpTensor c_warp_in_tensor;
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(in_lds_window, c_warp_in_tensor_casted);
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
const auto c_out_tensor =
|
||||
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
@@ -189,7 +352,13 @@ struct CShuffleEpilogue
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -15,17 +15,21 @@ template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool UseRawStore_ = true>
|
||||
bool UseRawStore_ = true,
|
||||
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
|
||||
struct Default2DEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
};
|
||||
|
||||
template <typename AccDataType_,
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename CLayout_,
|
||||
bool kPadM_,
|
||||
@@ -34,10 +38,17 @@ template <typename AccDataType_,
|
||||
index_t kNPerXdl_,
|
||||
index_t kKPerXdl_,
|
||||
bool isCTransposed_,
|
||||
bool UseRawStore_ = true>
|
||||
struct DefaultGemm2DEpilogueProblem
|
||||
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
|
||||
bool UseRawStore_ = true,
|
||||
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
|
||||
struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataType_,
|
||||
ODataType_,
|
||||
kPadM_,
|
||||
kPadN_,
|
||||
UseRawStore_,
|
||||
MemoryOperation_>
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
@@ -54,22 +65,20 @@ struct Default2DEpilogue
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::UseRawStore;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
|
||||
{
|
||||
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
@@ -81,7 +90,7 @@ struct Default2DEpilogue
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
@@ -91,27 +100,43 @@ struct Default2DEpilogue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& /* unused */,
|
||||
void* = nullptr) const
|
||||
{
|
||||
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ODataType,
|
||||
ODataType,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
using WG = WarpGemmDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
@@ -134,7 +159,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
else
|
||||
{
|
||||
// In this case each thread has just a single item in Ndim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
return (WG::WarpGemmAttribute::Impl::kCNLane *
|
||||
WG::WarpGemmAttribute::Impl::kBNBlock) /
|
||||
WG::kN;
|
||||
}
|
||||
}
|
||||
// M is contiguous dimension
|
||||
@@ -143,7 +170,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
if constexpr(isCTransposed)
|
||||
{
|
||||
// In this case each thread has just a single item in Mdim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
return (WG::WarpGemmAttribute::Impl::kCNLane *
|
||||
WG::WarpGemmAttribute::Impl::kAMBlock) /
|
||||
WG::kN;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -162,6 +191,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -3,10 +3,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename BlockPolicy_>
|
||||
struct BlockFlatmmASmemBSmemCRegV1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto idxM = I0;
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
ABlockWindow& a_warp_windows,
|
||||
BFlatBlockTensor& b_warp_tensor) const
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp =
|
||||
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user