Padding for attention: bmm+scale+softmax+bmm kernel (#385)

* add padding algo for bmm+scale+softmax+bmm. Version for verification

* remove verification code

* remove comments

* add padded bmm scale softmax bmm example

* format

* refactor

* add comments for usages of padding bmm+scale+softmax+bmm

Co-authored-by: Chao Liu <lc.roy86@gmail.com>

[ROCm/composable_kernel commit: 45adb736e7]
This commit is contained in:
Shaojie WANG
2022-08-31 00:01:37 +08:00
committed by GitHub
parent 65e451c3ca
commit e26256cd7d
6 changed files with 436 additions and 2 deletions

View File

@@ -111,6 +111,15 @@ __global__ void
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
// When using NPadding as GemmSpecialization, AccElementwiseOperation should be set to
// ScaleAndResetNaNToMinusInfinity.
// if !isNan(AccElement)
// AccElement *= scale
// else
// AccElement = -INFINITY
// Otherwise, result may be wrong.
template <typename ALayout,
typename BLayout, // B0Layout
typename B1Layout,

View File

@@ -97,6 +97,22 @@ struct Scale
float scale_;
};
struct ScaleAndResetNaNToMinusInfinity
{
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x;
};
float scale_;
};
struct UnaryDivide
{
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}

View File

@@ -349,9 +349,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

View File

@@ -1023,6 +1023,8 @@ struct NumericLimits
{
return std::numeric_limits<T>::quiet_NaN();
}
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
};
template <>