mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-25 07:14:37 +00:00
[CK_TILE] FA bwd kernels optimization (#1397)
* tmp save
* fix batch deterministic bugs
* fix group deterministic bugs
* codegen update
* reorder files
* bias support
* hd256 bias support
* bwd smoke test update
* simplify convert dq
* fix hd256 dropout scratch
* do{}while() -> while(){}
* comments
* remove FmhaBwdTilePartitioner
* save clear_tile
* refactor dropout
* code cleanup
* code cleanup
* comments
* fix epilogue problem
* fix fwd dropout
* group convert_dq opt
* fix dq alignment
* Do not store storerandval in bwd for flash attention integration
* fix hd32 error and boost performance
* revert
* Remove duplicated WarpGemm definitions in the policy file
* dropout patch for mrepeat 16*16
* code sync up
* dq_acc stride
* dq_acc stride stuff
* codegen update
* fwd dropout revert
* fix hd128 scratches and boost performance
* receipt 3 for simplified smoke test
* more strides for fa integration
* fix hd64 scratches and boost performance
* non-iglp pipeline for headdim padding cases
* dpad same as dvpad for flash attention integration
* unpadded lse&d for group mode
* Support unpad layout for group lse
* Support unpad lse layout for splitkv
* Fix stride for splitkv kernel
* fix unpadded lse issue in fwd splitkv
* comment
* solve lds read&write conflicts
* rename
* bias rename
* tile index revert
---------
Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
[ROCm/composable_kernel commit: 79a5d9c10c]
This commit is contained in:
@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
|
||||
};
|
||||
|
||||
// 2D XOR, NOTE: "xor" is a keyword
|
||||
template <typename LowLengths, typename RightShift>
|
||||
template <typename LowLengths>
|
||||
struct xor_t : public base_transform<2, 2>
|
||||
{
|
||||
static constexpr auto type_enum = coord_transform_enum::xor_t;
|
||||
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
|
||||
using UpLengths = LowLengths;
|
||||
|
||||
UpLengths up_lengths_;
|
||||
RightShift right_shift_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
|
||||
const RightShift& right_shift)
|
||||
: up_lengths_{low_lengths}, right_shift_{right_shift}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
|
||||
{
|
||||
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
idx_low(number<0>{}) = idx_up[number<0>{}];
|
||||
|
||||
const auto idx_low_1_tmp =
|
||||
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
|
||||
|
||||
const auto idx_low_1 =
|
||||
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
|
||||
|
||||
idx_low(number<1>{}) = idx_low_1;
|
||||
idx_low(number<1>{}) =
|
||||
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<RightShift>::value;
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
// MUST be static function
|
||||
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
|
||||
array<index_t, 2> up_vector_lengths = low_vector_lengths;
|
||||
array<index_t, 2> up_vector_strides = low_vector_strides;
|
||||
|
||||
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
|
||||
{
|
||||
if(low_vector_lengths[1] != -1)
|
||||
{
|
||||
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_shift_: ");
|
||||
print(right_shift_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
|
||||
return modulo<Modulus, UpLength>{modulus, up_length};
|
||||
}
|
||||
|
||||
template <typename LowLengths, typename RightShift>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
|
||||
const RightShift& right_shift)
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
|
||||
return xor_t<LowLengths>{low_lengths};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
|
||||
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
|
||||
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
|
||||
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u32
// Unsigned 32-bit vector aliases, declared with the compiler's
// ext_vector_type attribute for 2, 4, 8, 16, 32 and 64 elements.
|
||||
// using uint32_t = ...
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
|
||||
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
|
||||
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
|
||||
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
|
||||
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i16
|
||||
// using int16_t = ...
|
||||
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
return make_tuple(
|
||||
make_static_tile_distribution(
|
||||
tile_distribution_encoding<typename Encoding::RsLengths,
|
||||
decltype(sliced_h_lengths), // only need to change the
|
||||
// h_lengths type
|
||||
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
|
||||
// change the
|
||||
// h_lengths type
|
||||
typename Encoding::Ps2RHssMajor,
|
||||
typename Encoding::Ps2RHssMinor,
|
||||
typename Encoding::Ys2RHsMajor,
|
||||
|
||||
@@ -53,6 +53,39 @@ class philox
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
// Writes 8 bytes (two 32-bit words) of Philox output into `out`.
//
// out         - destination buffer; it is written through a uint32_t* at
//               word offsets 0 and 1.
//               NOTE(review): presumably the caller must supply at least
//               8 writable bytes aligned for uint32_t access - confirm.
// subsequence - forwarded unchanged to get_philox_4x32().
// start_idx   - selects which words of the 4-word result are emitted:
//               words start_idx and start_idx + 2.
//               NOTE(review): nothing here range-checks start_idx; only
//               0 and 1 keep start_idx + 2 inside the 4-element vector.
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
// Copy the x/y/z/w members into a vector type so that a word can be
// picked with an index held in a variable.
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
// View the byte buffer as 32-bit words so each store writes 4 bytes.
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
// The second word is taken two positions after the first.
out_tmp[1] = tmp[start_idx + 2];
|
||||
}
|
||||
|
||||
// Writes 4 bytes (one 32-bit word) of Philox output into `out`.
//
// out         - destination buffer; it is written through a uint32_t* at
//               word offset 0.
//               NOTE(review): presumably the caller must supply at least
//               4 writable bytes aligned for uint32_t access - confirm.
// subsequence - forwarded unchanged to get_philox_4x32().
// start_idx   - index of the word of the 4-word result to emit.
//               NOTE(review): nothing here range-checks start_idx; the
//               vector it indexes has 4 elements (0..3).
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
// Copy the x/y/z/w members into a vector type so that a word can be
// picked with an index held in a variable.
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
// View the byte buffer as a 32-bit word so the store writes 4 bytes.
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user