Merge commit '5d6d236b255b4ef9c8f38e1bd35975acda0af19a' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-07 14:15:40 +00:00
parent d9959414c5
commit be60fd573b
46 changed files with 1788 additions and 806 deletions

View File

@@ -27,6 +27,7 @@
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
@@ -74,6 +75,7 @@
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pass_through{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
template <typename LowLength>
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
{
printf("pass_through{");
printf("up_lengths_: ");
print(pt.get_upper_lengths());
printf("}");
}
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
{
printf("pad{");
printf("up_lengths_: ");
print(p.up_lengths_);
printf(", left_pad_length_: ");
print(p.left_pad_length_);
printf(", right_pad_length_: ");
print(p.right_pad_length_);
printf("}");
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
@@ -330,24 +326,20 @@ struct left_pad
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("left_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
{
printf("left_pad{");
printf("up_lengths_: ");
print(lp.up_lengths_);
printf(", left_pad_length_: ");
print(lp.left_pad_length_);
printf("}");
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("right_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
{
printf("right_pad{");
printf("up_lengths_: ");
print(rp.up_lengths_);
printf(", right_pad_length_: ");
print(rp.right_pad_length_);
printf("}");
}
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be either of the followings:
// 1) Tuple of index_t, which is known at run-time, or
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Coefficients>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
template <typename UpLengths, typename Coefficients>
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
{
printf("embed{");
printf("up_lengths_: ");
print(e.up_lengths_);
printf(", coefficients_: ");
print(e.coefficients_);
printf("}");
}
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("merge_v2_magic_division{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
{
printf("merge_v2_magic_division{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
{
printf("merge_v3_division_mod{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", low_lengths_scan_: ");
print(m.low_lengths_scan_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("unmerge{");
//
printf("up_lengths_");
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print(up_lengths_scan_);
printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
{
printf("unmerge{");
printf("up_lengths_: ");
print(u.up_lengths_);
printf(", up_lengths_scan_: ");
print(u.up_lengths_scan_);
printf("}");
}
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
{
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
template <typename LowerIndex>
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
{
printf("freeze{");
printf("low_idx_: ");
print(f.low_idx_);
printf("}");
}
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
{
return ck_tile::is_known_at_compile_time<UpperLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("insert{");
//
print(up_lengths_);
printf("}");
}
};
template <typename UpperLength>
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
{
printf("insert{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf("}");
}
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
template <typename UpLengths>
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
{
printf("replicate{");
printf("up_lengths_: ");
print(r.up_lengths_);
printf("}");
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
ck_tile::is_known_at_compile_time<SliceEnd>::value;
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
} // namespace ck
}; // namespace ck
template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
{
printf("slice{");
printf("up_lengths_: ");
print(s.up_lengths_);
printf(", slice_begin_: ");
print(s.slice_begin_);
printf(", slice_end_: ");
print(s.slice_end_);
printf("}");
}
/*
* \brief lower_idx = upper_idx % modulus.
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
{
printf("modulo{");
printf("modulus_: ");
print(m.modulus_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("xor_t{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
{
printf("xor_t{");
printf("up_lengths_: ");
print(x.up_lengths_);
printf("}");
}
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<OffsetLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("offset{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
print(offset_length_);
printf("}");
}
};
template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
{
printf("offset{");
printf("up_lengths_: ");
print(o.up_lengths_);
printf(", offset_length_: ");
print(o.offset_length_);
printf("}");
}
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
IndexingAdaptor::is_known_at_compile_time();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
{
printf("indexing{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf(", iadaptor_: ");
print(i.iadaptor_);
printf("}");
}
//*******************************************************************************************************
template <typename LowLength>

View File

@@ -77,6 +77,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -317,4 +318,51 @@ struct TileDistributionEncodingPattern2D<BlockSize,
}
};
// Helper function to convert enum to string
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
{
switch(pattern)
{
case tile_distribution_pattern::thread_raked: return "thread_raked";
case tile_distribution_pattern::warp_raked: return "warp_raked";
case tile_distribution_pattern::block_raked: return "block_raked";
default: return "unknown";
}
}
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups>
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>&)
{
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>;
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
"VecSize:%d, %s>: ",
BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern_to_string(DistributionPattern));
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
PatternType::Y0,
PatternType::Y1,
PatternType::Y2,
PatternType::X0,
PatternType::X1);
}
} // namespace ck_tile

View File

@@ -218,4 +218,19 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
#endif
}
/// Helper function to convert address space enum to string
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
{
switch(addr_space)
{
case address_space_enum::generic: return "generic";
case address_space_enum::global: return "global";
case address_space_enum::lds: return "lds";
case address_space_enum::sgpr: return "sgpr";
case address_space_enum::constant: return "constant";
case address_space_enum::vgpr: return "vgpr";
default: return "unknown";
}
}
} // namespace ck_tile

View File

@@ -177,9 +177,27 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename T, index_t N>
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
{
printf("array{size: %ld, data: [", static_cast<long>(N));
for(index_t i = 0; i < N; ++i)
{
if(i > 0)
printf(", ");
print(a[i]);
}
printf("]}");
}
template <typename T>
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
{
printf("array{size: 0, data: []}");
}
template <typename, typename>
struct vector_traits;

View File

@@ -139,26 +139,21 @@ struct map
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
template <typename key, typename data, index_t max_size>
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
{
printf("map{size_: %d, impl_: [", m.size_);
for(const auto& [k, d] : m)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
printf("]}");
}
} // namespace ck_tile

View File

@@ -9,13 +9,10 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct sequence;
@@ -196,15 +193,24 @@ struct sequence
{
return sequence<f(Is)...>{};
}
CK_TILE_HOST_DEVICE static void print()
{
printf("sequence{size: %d, data: [", size());
((printf("%d ", Is)), ...);
printf("]}");
}
};
template <index_t... Is>
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
{
printf("sequence<");
if constexpr(sizeof...(Is) > 0)
{
bool first = true;
(([&first](index_t value) {
printf("%s%d", first ? "" : ", ", value);
first = false;
}(Is)),
...);
}
printf(">");
}
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;

View File

@@ -300,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename, typename = void>
template <typename... T>
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
{
printf("tuple<");
if constexpr(sizeof...(T) > 0)
{
bool first = true;
static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) {
if(!first)
printf(", ");
print(t.get(i));
first = false;
});
}
printf(">");
}
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
struct vector_traits<tuple<T...>, void>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);

View File

@@ -0,0 +1,102 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
namespace ck_tile {
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
using raw_type = uint8_t;
using type = raw_type;
raw_type data;
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
{
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
CK_TILE_HOST_DEVICE constexpr operator float() const;
constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; }
constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; }
};
using e8m0_t = e8m0_bexp_t;
using e8m0_raw_t = typename e8m0_t::raw_type;
template <>
struct numeric_traits<e8m0_t>
{
using bitwise_type = e8m0_raw_t;
static constexpr int exp = 8;
static constexpr int mant = 0;
static constexpr int bias = 127;
static constexpr int PackedSize = 1;
};
// limits
template <class T>
struct numeric;
template <>
struct numeric<e8m0_t>
{
static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127
static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127
static constexpr e8m0_raw_t binary_nan = 0b11111111;
CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; }
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); }
CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); }
CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); }
CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); }
};
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
{
using traits = numeric_traits<float>;
if(data == numeric<e8m0_t>::binary_nan)
{
return traits::NaN;
}
else if(data == 0)
{
return std::numeric_limits<float>::min();
}
else
{
return bit_cast<float>(static_cast<traits::bitwise_type>(data) << traits::mant);
}
}
} // namespace ck_tile

View File

@@ -19,14 +19,18 @@ struct constant
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <auto v>
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
{
printf("%ld", static_cast<long>(v));
}
template <typename T, T v>
struct integral_constant : constant<v>
{
using value_type = T;
using type = integral_constant; // using injected-class-name
static constexpr T value = v;
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>

View File

@@ -12,15 +12,19 @@ struct numeric_utils : numeric_traits<T>
using traits = numeric_traits<T>;
using _numeric = numeric<T>;
using raw_type = typename T::raw_type;
using raw_type = typename traits::bitwise_type;
static constexpr int exp_mask = (1 << traits::exp) - 1;
static constexpr int get_exponent(raw_type x)
static constexpr raw_type get_exponent(raw_type x)
{
// TODO: check if repeated calls are optimized.
return (x >> traits::mant) & exp_mask;
}
static constexpr raw_type get_exponent(const T& x)
{
return get_exponent(bit_cast<raw_type>(x));
}
static constexpr bool is_positive(raw_type x)
{
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
@@ -33,7 +37,7 @@ struct numeric_utils : numeric_traits<T>
static constexpr double get_mantissa(raw_type x)
{
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
for(uint32_t i = 0; i < traits::mant; ++i)
for(raw_type i = 0; i < traits::mant; ++i)
{
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
x >>= 1;
@@ -43,22 +47,23 @@ struct numeric_utils : numeric_traits<T>
};
template <typename T>
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
{
using utils = numeric_utils<T>;
static constexpr int e8m0_bias = 127; // TODO: make it generic.
float sign = utils::is_positive(data) ? 1.0 : -1.0;
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
float mant = utils::get_mantissa(data);
using utils = numeric_utils<T>;
float sign = utils::is_positive(data) ? 1.0 : -1.0;
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
float mant = utils::get_mantissa(data);
return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
return std::ldexp(sign * mant * scale, exp);
}
template <typename T>
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
{
using bitwise_type = typename numeric_traits<T>::bitwise_type;
value /= scale;
if(std::abs(value) > float(numeric<T>::max()))
{
float max_value = numeric<T>::max();

View File

@@ -23,14 +23,11 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
// TODO: Add stochastic method
struct pk_float4_e2m1_t
{
static constexpr int exponent = 2;
static constexpr int mantissa = 1;
static constexpr int bias = 1;
// TODO: Can we merge raw_type and type?
using raw_type = uint8_t;
using type = raw_type;
@@ -41,18 +38,27 @@ struct pk_float4_e2m1_t
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
{
}
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
: data{float_to_e2m1(init, scale)}
{
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
CK_TILE_HOST_DEVICE constexpr operator float() const;
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
template <index_t I>
CK_TILE_HOST_DEVICE constexpr raw_type unpack(number<I>) const;
@@ -191,131 +197,160 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
} // namespace impl
#endif
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16_t>(data);
return impl::_from_f4<bf16_t>(data, scale);
#else
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16x2_t>(data);
return impl::_from_f4<bf16x2_t>(data, scale);
#else
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
#endif
}
// TODO: make float_to_e2m1 generic so that we can convert from directrly.
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x)
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
return impl::_to_f4(x, scale);
#else
return convert_to_type<pk_fp4_t>(x);
return convert_to_type<pk_fp4_t>(x, scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); }
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); }
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); }
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
{
return float_to_e2m1(x, scale);
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
return impl::_to_f4(x, scale);
#else
return float_to_e2m1(type_convert<float>(x));
return float_to_e2m1(type_convert<float>(x), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
return impl::_to_f4(x, scale);
#else
return float_to_e2m1(type_convert<float>(x));
return float_to_e2m1(type_convert<float>(x), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
float_to_e2m1(type_convert<float>(x[1])));
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
float_to_e2m1(type_convert<float>(x[1]), scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
float_to_e2m1(type_convert<float>(x[1])));
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
float_to_e2m1(type_convert<float>(x[1]), scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x);
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
{
return x.to_fp32x2(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
{
return x.to_fp16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
{
return x.to_bf16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
{
return x.to_float(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
{
return x.to_fp16(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
{
return x.to_bf16(scale);
}
#if TEST_convert_with_table == 0
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32_t>(data);
return impl::_from_f4<fp32_t>(data, scale);
#else
return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32x2_t>(data);
return impl::_from_f4<fp32x2_t>(data, scale);
#else
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16_t>(data);
return impl::_from_f4<fp16_t>(data, scale);
#else
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16x2_t>(data);
return impl::_from_f4<fp16x2_t>(data, scale);
#else
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
#endif
}
#else
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
return e2m1_to_fp32_table[data & 0xf];
return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]};
return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale};
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
return e2m1_to_fp16_table[data & 0xf];
return type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale;
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
return fp16x2_t{
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
}
#endif

View File

@@ -64,6 +64,7 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
#undef CK_TILE_TYPE_CONVERT
} // namespace ck_tile
@@ -71,16 +72,36 @@ CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
namespace ck_tile {
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
#undef CK_TILE_TYPE_CONVERT
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
float scale) \
{ \
return sname_##_to_##dname_(x, scale); \
} \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return sname_##_to_##dname_(x, 1.f); \
}
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4)
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4)
#undef CK_TILE_SCALED_TYPE_CONVERT
#endif
} // namespace ck_tile

View File

@@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T, typename>
template <typename T, typename = void>
struct vector_traits
{
using scalar_type =
@@ -94,7 +94,7 @@ struct vector_traits
// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;

View File

@@ -210,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: generic, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Global
@@ -757,28 +735,6 @@ struct buffer_view<address_space_enum::global,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Global, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: LDS
@@ -1138,28 +1094,6 @@ struct buffer_view<address_space_enum::lds,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Lds, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Vgpr
@@ -1313,28 +1247,6 @@ struct buffer_view<address_space_enum::vgpr,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Vgpr, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
template <address_space_enum BufferAddressSpace,
@@ -1360,4 +1272,25 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
p, buffer_size, invalid_element_value};
}
// Generalized print function for all buffer_view variants
template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence>
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
Coherence>& bv)
{
printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
address_space_to_string(BufferAddressSpace),
static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
print(bv.buffer_size_);
printf(", invalid_element_value_: ");
print(bv.invalid_element_value_);
printf("}");
}
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -53,10 +53,13 @@ struct is_null_tile_window<null_tile_window<T>> : public std::true_type
};
} // namespace impl
template <typename T>
constexpr bool is_null_tile_window_v = impl::is_null_tile_window<remove_cvref_t<T>>::value;
template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
return is_null_tile_window_v<remove_cvref_t<T>>;
}
template <typename WindowLengths>

View File

@@ -305,42 +305,45 @@ struct tensor_adaptor
get_container_subset(vector_strides, top_dims));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_adaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
}
private:
Transforms transforms_;
ElementSize element_size_;
};
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename BottomDimensionHiddenIds,
typename TopDimensionHiddenIds>
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
BottomDimensionHiddenIds,
TopDimensionHiddenIds>& adaptor)
{
printf("tensor_adaptor{\n");
printf(" transforms: [");
print(adaptor.get_transforms());
printf("],\n");
printf(" LowerDimensionHiddenIds: [");
print(LowerDimensionHiddenIdss{});
printf("],\n");
printf(" UpperDimensionHiddenIds: [");
print(UpperDimensionHiddenIdss{});
printf("],\n");
printf(" BottomDimensionHiddenIds: [");
print(BottomDimensionHiddenIds{});
printf("],\n");
//
printf(" TopDimensionHiddenIds: [");
print(TopDimensionHiddenIds{});
printf("]\n}\n");
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>

View File

@@ -140,25 +140,37 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_descriptor{");
// tensor_adaptor
Base::print();
printf(", ");
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
printf("}");
}
// TODO make these private
ElementSpaceSize element_space_size_;
};
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename TopDimensionHiddenIds,
typename ElementSpaceSize,
typename GuaranteedVectorLengths,
typename GuaranteedVectorStrides>
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
TopDimensionHiddenIds,
ElementSpaceSize,
GuaranteedVectorLengths,
GuaranteedVectorStrides>& descriptor)
{
printf("tensor_descriptor{\n");
// first print the tensor adaptor part of the descriptor using the base class print
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
printf("element_space_size_: %ld,\n",
static_cast<long>(descriptor.get_element_space_size().value));
printf("guaranteed_vector_lengths: ");
print(GuaranteedVectorLengths{});
printf(",\nguaranteed_vector_strides: ");
print(GuaranteedVectorStrides{});
printf("}\n}\n");
}
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,

View File

@@ -228,24 +228,6 @@ struct tile_distribution
{
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution{");
//
printf("tile_distribution_encoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
@@ -710,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
}
} // namespace detail
// Free print function for tile_distribution
template <typename PsYs2XsAdaptor_,
typename Ys2DDescriptor_,
typename StaticTileDistributionEncoding_,
typename TileDistributionDetail_>
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
Ys2DDescriptor_,
StaticTileDistributionEncoding_,
TileDistributionDetail_>& distribution)
{
printf("tile_distribution{");
printf("tile_distribution_encoding: ");
print(StaticTileDistributionEncoding_{});
printf(", ");
printf("ps_ys_to_xs_: ");
print(distribution.ps_ys_to_xs_);
printf(", ");
printf("ys_to_d_: ");
print(distribution.ys_to_d_);
printf("}\n");
}
} // namespace ck_tile

View File

@@ -428,109 +428,7 @@ struct tile_distribution_encoding
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding::detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("detail: ");
print(detail{});
//
printf("}");
}
};
template <typename encoding, typename shuffle>
@@ -896,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
}
} // namespace detail
// Free print function for tile_distribution_encoding::detail
template <typename RsLengths_,
typename HsLengthss_,
typename Ps2RHssMajor_,
typename Ps2RHssMinor_,
typename Ys2RHsMajor_,
typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void
print(const typename tile_distribution_encoding<RsLengths_,
HsLengthss_,
Ps2RHssMajor_,
Ps2RHssMinor_,
Ys2RHsMajor_,
Ys2RHsMinor_>::detail& detail_obj)
{
printf("tile_distribution_encoding::detail{");
printf("ndim_rh_major_: ");
print(detail_obj.ndim_rh_major_);
printf(", ");
printf("ndim_span_major_: ");
print(detail_obj.ndim_span_major_);
printf(", ");
printf("ndims_rhs_minor_: ");
print(detail_obj.ndims_rhs_minor_);
printf(", ");
printf("ndim_rh_major_: ");
print(detail_obj.ndim_rh_major_);
printf(", ");
printf("max_ndim_rh_minor_: ");
print(detail_obj.max_ndim_rh_minor_);
printf(", ");
printf("rhs_lengthss_: ");
print(detail_obj.rhs_lengthss_);
printf(", ");
printf("ys_lengths_: ");
print(detail_obj.ys_lengths_);
printf(", ");
printf("rhs_major_minor_to_ys_: ");
print(detail_obj.rhs_major_minor_to_ys_);
printf(", ");
printf("ndims_span_minor_: ");
print(detail_obj.ndims_span_minor_);
printf(", ");
printf("max_ndim_span_minor_: ");
print(detail_obj.max_ndim_span_minor_);
printf(", ");
printf("ys_to_span_major_: ");
print(detail_obj.ys_to_span_major_);
printf(", ");
printf("ys_to_span_minor_: ");
print(detail_obj.ys_to_span_minor_);
printf(", ");
printf("distributed_spans_lengthss_: ");
print(detail_obj.distributed_spans_lengthss_);
printf(", ");
printf("ndims_distributed_spans_minor_: ");
print(detail_obj.ndims_distributed_spans_minor_);
printf(", ");
printf("ps_over_rs_derivative_: ");
print(detail_obj.ps_over_rs_derivative_);
printf("}");
}
// Free print function for tile_distribution_encoding
template <typename RsLengths_,
typename HsLengthss_,
typename Ps2RHssMajor_,
typename Ps2RHssMinor_,
typename Ys2RHsMajor_,
typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
HsLengthss_,
Ps2RHssMajor_,
Ps2RHssMinor_,
Ys2RHsMajor_,
Ys2RHsMinor_>& encoding)
{
printf("tile_distribution_encoding{");
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
printf("rs_lengths_: ");
print(encoding.rs_lengths_);
printf(", ");
printf("hs_lengthss_: ");
print(encoding.hs_lengthss_);
printf(", ");
printf("ps_to_rhss_major_: ");
print(encoding.ps_to_rhss_major_);
printf(", ");
printf("ps_to_rhss_minor_: ");
print(encoding.ps_to_rhss_minor_);
printf(", ");
printf("ys_to_rhs_major_: ");
print(encoding.ys_to_rhs_major_);
printf(", ");
printf("ys_to_rhs_minor_: ");
print(encoding.ys_to_rhs_minor_);
printf(", ");
printf("}");
}
} // namespace ck_tile

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
/// can be printed.
template <typename T>
CK_TILE_HOST_DEVICE void print(const T&)
{
static_assert(sizeof(T) == 0,
"No print implementation available for this type. Please specialize "
"ck_tile::print for your type.");
}
/// Specialization for int
template <>
CK_TILE_HOST_DEVICE void print(const int& value)
{
printf("%d", value);
}
/// Specialization for float
template <>
CK_TILE_HOST_DEVICE void print(const float& value)
{
printf("%f", value);
}
/// Specialization for double
template <>
CK_TILE_HOST_DEVICE void print(const double& value)
{
printf("%f", value);
}
/// Specialization for long
template <>
CK_TILE_HOST_DEVICE void print(const long& value)
{
printf("%ld", value);
}
/// Specialization for unsigned int
template <>
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
{
printf("%u", value);
}
/// Specialization for char
template <>
CK_TILE_HOST_DEVICE void print(const char& value)
{
printf("%c", value);
}
/// Specialization for array
template <typename T, size_t N>
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
{
printf("[");
for(size_t i = 0; i < N; ++i)
{
if(i > 0)
printf(", ");
print(value[i]); // Recursively call print for each element
}
printf("]");
}
} // namespace ck_tile

View File

@@ -409,7 +409,13 @@ struct HostTensor
}
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
void SetZero()
{
if constexpr(std::is_same_v<T, e8m0_t>)
std::fill(mData.begin(), mData.end(), e8m0_t{1.f});
else
std::fill(mData.begin(), mData.end(), 0);
}
template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)

View File

@@ -24,8 +24,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"

View File

@@ -52,8 +52,6 @@ struct FmhaBwdDQDKDVKernel
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
@@ -85,8 +83,6 @@ struct FmhaBwdDQDKDVKernel
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadSeqLenK) n += "sk";
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
@@ -100,7 +96,7 @@ struct FmhaBwdDQDKDVKernel
"r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "_npad" : "_" + pn) +
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" );
@@ -1221,7 +1217,7 @@ struct FmhaBwdDQDKDVKernel
const auto q_dram = pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
@@ -1232,7 +1228,7 @@ struct FmhaBwdDQDKDVKernel
const auto k_dram = pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1244,22 +1240,15 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
const auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
return pad_tensor_view(
lse_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
}();
// lse and d should be fine to read unpaded data as they are not on the reduction dimension
const auto lse_dram = make_naive_tensor_view_packed<address_space_enum::global>(
lse_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
const auto d_dram = [&]() {
const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
return pad_tensor_view(
d_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
}();
const auto d_dram = make_naive_tensor_view_packed<address_space_enum::global>(
d_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
do_ptr,
@@ -1270,7 +1259,7 @@ struct FmhaBwdDQDKDVKernel
const auto do_dram = pad_tensor_view(
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
auto q_dram_window = make_tile_window(
q_dram,
@@ -1313,7 +1302,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}();
return make_tile_window(
@@ -1341,7 +1330,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}();
return make_tile_window(
@@ -1376,9 +1365,8 @@ struct FmhaBwdDQDKDVKernel
number<FmhaPipeline::kAlignmentBias>{},
number<1>{});
return pad_tensor_view(bias_dram_naive,
bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
return pad_tensor_view(
bias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
}();
return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
@@ -1406,9 +1394,8 @@ struct FmhaBwdDQDKDVKernel
number<FmhaPipeline::kAlignmentBias>{},
number<1>{});
return pad_tensor_view(dbias_dram_naive,
bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
return pad_tensor_view(
dbias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
}();
return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
@@ -1495,9 +1482,8 @@ struct FmhaBwdDQDKDVKernel
number<1>{},
number<1>{});
return pad_tensor_view(randval_dram_naive,
randval_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
return pad_tensor_view(
randval_dram_naive, randval_dram_window_lengths, sequence<false, true>{});
}();
return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
@@ -1550,7 +1536,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dk_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}();
auto dv_dram = [&]() {
@@ -1564,7 +1550,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dv_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
auto dk_dram_window = make_tile_window(

View File

@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "kr_ktr_vr";
@@ -554,7 +551,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

View File

@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "kr_ktr_vr_iglp";
@@ -590,7 +587,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
@@ -849,7 +845,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
namespace ck_tile {
template <typename Problem>
class BlockFmhaBwdDQDKDVPipelineSelector
{
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
public:
using type = std::conditional_t<has_dpad,
BlockFmhaBwdDQDKDVPipelineKRKTRVR<Problem>,
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<Problem>>;
};
template <typename Problem>
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem>::type
{
public:
static constexpr const char* name = "auto";
};
} // namespace ck_tile

View File

@@ -1,15 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum
{
KRKTRVR_IGLP = 0,
KRKTRVR,
};
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -55,13 +55,13 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
};
template <typename ODataType_,