Merge commit '3f57ec3d2dc856a30ca1c652eda19e5dd4ee6041' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-12 23:14:34 +00:00
parent 36f828fd80
commit 3bdebe4686
18 changed files with 2547 additions and 291 deletions

View File

@@ -262,219 +262,67 @@ struct PassThroughPack2
struct PassThrough
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <class T>
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
template <class Y, class X>
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
{
y = x;
/* Only do the assignment when
- y is an *l-value* and
- y is *not* const */
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
{
y = ck_tile::type_convert<raw_t<Y>>(x);
}
/* otherwise (r-value or const) → do nothing */
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = type_convert<float>(x);
}
// Suppress unused parameter warning for ds
((void)ds, ...);
template <>
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
// Just assign e with c
if constexpr(std::is_same_v<E, C>)
{
e = c;
}
else
{
e = ck_tile::type_convert<E>(c);
}
}
};
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
struct MultiDMultiply
{
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = x;
}
// Start with the base value c
float result = ck_tile::type_convert<float>(c);
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = x;
}
// Multiply by each D parameter using fold expression
((result *= ck_tile::type_convert<float>(ds)), ...);
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
e = ck_tile::type_convert<E>(result);
}
};
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
struct MultiDAdd
{
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = x;
}
// Start with the base value c
float result = ck_tile::type_convert<float>(c);
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
// Add by each D parameter using fold expression
((result += ck_tile::type_convert<float>(ds)), ...);
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
const ck_tile::bf16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
const ck_tile::fp16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
const ck_tile::fp8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
const ck_tile::bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
e = ck_tile::type_convert<E>(result);
}
};

View File

@@ -6,10 +6,10 @@
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"