Merge commit '191c62967bf05f58641725b88f038bea462fe651' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-11 13:24:17 +00:00
parent 63c6b9b93c
commit b2ec5dde0a
4 changed files with 7 additions and 63 deletions

View File

@@ -35,8 +35,6 @@ struct Add
return type_convert<T>(y_ + x_);
}
static constexpr bool requires_special_combine = false;
};
struct SquareAdd
@@ -64,28 +62,6 @@ struct SquareAdd
float x_ = type_convert<float>(x);
return type_convert<T>(y_ + (x_ * x_));
}
// Merges two already-accumulated partial results into one value.
// Available only when T is float, double, int32_t or int8_t.
// The partials are summed as-is; no squaring is applied at this stage.
template <typename T,
          typename = std::enable_if_t<std::disjunction_v<std::is_same<T, float>,
                                                         std::is_same<T, double>,
                                                         std::is_same<T, int32_t>,
                                                         std::is_same<T, int8_t>>>>
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1,
                                                        const T& partial2) const
{
    const T combined = partial1 + partial2;
    return combined;
}
// Merges two already-accumulated partial results held in half_t, bf16_t,
// fp8_t or bf8_t. Both inputs are converted to float, added in float, and the
// sum is converted back to T. No squaring is applied at this stage.
//
// The inputs are only read, so they are taken by const reference; this lets
// the overload accept const objects and temporaries as well as mutable
// lvalues. The constraint is written as a non-type template parameter so that
// this overload remains a distinct template from any sibling overload that
// also takes (const T&, const T&) and constrains T through a defaulted type
// parameter.
template <typename T,
          std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
                               std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
                           bool> = true>
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1,
                                                        const T& partial2) const
{
    const float partial1_ = type_convert<float>(partial1);
    const float partial2_ = type_convert<float>(partial2);
    return type_convert<T>(partial1_ + partial2_);
}
static constexpr bool requires_special_combine = true;
};
struct Max
@@ -109,8 +85,6 @@ struct Max
{
return max(y, x);
}
static constexpr bool requires_special_combine = false;
};
struct AbsMax
@@ -134,8 +108,6 @@ struct AbsMax
{
return max(y, abs(x));
}
static constexpr bool requires_special_combine = false;
};
} // namespace ReduceOp