mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
CK Tile FA Training kernels (#1286)
* FA fwd dropout
* FA bwd
* epilogue reuse
* CMakeLists update
* [CK_TILE] support alibi (#1269)
* add alibi support
* fix code
* update code based on comment
* Support more hdim
* fix fp8 bias
* support seqlen_k=0 case
* remove unused printf
* fix format
---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>
* now fwd/bwd can build
* bwd alibi
* add bwd validation stream_config
* update generated filenames
* update bwd kernel launch
* CK_TILE_HOST_DEVICE in philox
* Transpose -> transpose
* format
* format
* format
* Generate the instance for FA required
* format
* fix error in WarpGemm
---------
Co-authored-by: danyao12 <danyao12>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
[ROCm/composable_kernel commit: 2cab8d39e3]
This commit is contained in:
89
include/ck_tile/core/utility/philox_rand.hpp
Normal file
89
include/ck_tile/core/utility/philox_rand.hpp
Normal file
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
|
||||
class philox
|
||||
{
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
|
||||
: seed(reinterpret_cast<const uint2&>(seed_))
|
||||
{
|
||||
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
|
||||
{
|
||||
|
||||
uint4 counter_ = counter;
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter_);
|
||||
tmp->y = subsequence;
|
||||
|
||||
uint2 key_ = seed;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 6; i++)
|
||||
{
|
||||
counter_ = philox_single_round(counter_, key_);
|
||||
key_.x += kPhilox10A;
|
||||
key_.y += kPhilox10B;
|
||||
}
|
||||
uint4 output = philox_single_round(counter_, key_);
|
||||
return output;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
|
||||
const unsigned long long subsequence) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
|
||||
out_tmp[0] = tmp_ph.x;
|
||||
out_tmp[1] = tmp_ph.y;
|
||||
out_tmp[2] = tmp_ph.z;
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
const uint2 seed;
|
||||
|
||||
CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
|
||||
{
|
||||
uint2* res;
|
||||
unsigned long long tmp;
|
||||
tmp = static_cast<unsigned long long>(a) * b;
|
||||
res = reinterpret_cast<uint2*>(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
|
||||
{
|
||||
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user