mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Padding for attention: bmm+scale+softmax+bmm kernel (#385)
* add padding algo for bmm+scale+softmax+bmm. Version for verification
* remove verification code
* remove comments
* add padded bmm scale softmax bmm example
* format
* refactor
* add comments for usages of padding bmm+scale+softmax+bmm
Co-authored-by: Chao Liu <lc.roy86@gmail.com>
[ROCm/composable_kernel commit: 45adb736e7]
This commit is contained in:
@@ -111,6 +111,15 @@ __global__ void
|
||||
// Computes C = A * B0 * B1
|
||||
// ^^^^^^ (Acc0)
|
||||
// ^^^^^^^^^^^ (Acc1)
|
||||
|
||||
// When using NPadding as GemmSpecialization, AccElementwiseOperation should be set to
|
||||
// ScaleAndResetNaNToMinusInfinity.
|
||||
// if !isNan(AccElement)
|
||||
// AccElement *= scale
|
||||
// else
|
||||
// AccElement = -INFINITY
|
||||
// Otherwise, result may be wrong.
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout, // B0Layout
|
||||
typename B1Layout,
|
||||
|
||||
@@ -97,6 +97,22 @@ struct Scale
|
||||
float scale_;
|
||||
};
|
||||
|
||||
struct ScaleAndResetNaNToMinusInfinity
|
||||
{
|
||||
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x;
|
||||
};
|
||||
|
||||
float scale_;
|
||||
};
|
||||
|
||||
struct UnaryDivide
|
||||
{
|
||||
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
|
||||
|
||||
@@ -349,9 +349,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
p_a_grid,
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
|
||||
NumericLimits<FloatAB>::QuietNaN());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
|
||||
NumericLimits<FloatAB>::QuietNaN());
|
||||
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
|
||||
@@ -1023,6 +1023,8 @@ struct NumericLimits
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
Reference in New Issue
Block a user