[CK_TILE] FA bwd kernels optimization (#1397)

* tmp save

* fix batch deterministic bugs

* fix group deterministic bugs

* codegen update

* reorder files

* bias support

* hd256 bias support

* bwd smoke test update

* simplify convert dq

* fix hd256 dropout scratch

* do{}while() -> while(){}

* comments

* remove FmhaBwdTilePartitioner

* save clear_tile

* refactor dropout

* code cleanup

* code cleanup

* comments

* fix epilogue problem

* fix fwd dropout

* group convert_dq opt

* fix dq alignment

* Do not store storerandval in bwd for flash attention integration

* fix hd32 error and boost performance

* revert

* Remove duplicated WarpGemm definitions in the policy file

* dropout patch for mrepeat 16*16

* code sync up

* dq_acc stride

* dq_acc stride stuff

* codegen update

* fwd dropout revert

* fix hd128 scratches and boost performance

* receipt 3 for simplified smoke test

* more strides for fa integration

* fix hd64 scratches and boost performance

* non-iglp pipeline for headdim padding cases

* dpad same as dvpad for flash attention integration

* unpadded lse&d for group mode

* Support unpad layout for group lse

* Support unpad lse layout for splitkv

* Fix stride for splitkv kernel

* fix unpadded lse issue in fwd splitkv

* comment

* solve lds read&write conflicts

* rename

* bias rename

* tile index revert

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>

[ROCm/composable_kernel commit: 79a5d9c10c]
This commit is contained in:
Dan Yao
2024-08-17 04:40:10 +08:00
committed by GitHub
parent dffd5eacc0
commit 14402bb211
43 changed files with 5515 additions and 4222 deletions

View File

@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
};
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift>
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
{
static constexpr auto type_enum = coord_transform_enum::xor_t;
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
using UpLengths = LowLengths;
UpLengths up_lengths_;
RightShift right_shift_;
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
idx_low(number<0>{}) = idx_up[number<0>{}];
const auto idx_low_1_tmp =
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
idx_low(number<1>{}) = idx_low_1;
idx_low(number<1>{}) =
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<RightShift>::value;
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// MUST be static function
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
array<index_t, 2> up_vector_lengths = low_vector_lengths;
array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
}
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}");
}
};
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
return modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
return xor_t<LowLengths>{low_lengths};
}
template <typename LowLength, typename OffsetLength>

View File

@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));

View File

@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return make_tuple(
make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
// change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,

View File

@@ -53,6 +53,39 @@ class philox
out_tmp[3] = tmp_ph.w;
}
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
out_tmp[1] = tmp[start_idx + 2];
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
}
private:
struct ull2
{