mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] FA bwd kernels optimization (#1397)
* tmp save
* fix batch deterministic bugs
* fix group deterministic bugs
* codegen update
* reorder files
* bias support
* hd256 bias support
* bwd smoke test update
* simplify convert dq
* fix hd256 dropout scratch
* do{}while() -> while(){}
* comments
* remove FmhaBwdTilePartitioner
* save clear_tile
* refactor dropout
* code cleanup
* code cleanup
* comments
* fix epilogue problem
* fix fwd dropout
* group convert_dq opt
* fix dq alignment
* Do not store storerandval in bwd for flash attention integration
* fix hd32 error and boost performance
* revert
* Remove duplicated WarpGemm definitions in the policy file
* dropout patch for mrepeat 16*16
* code sync up
* dq_acc stride
* dq_acc stride stuff
* codegen update
* fwd dropout revert
* fix hd128 scratches and boost performance
* receipt 3 for simplified smoke test
* more strides for fa integration
* fix hd64 scratches and boost performance
* non-iglp pipeline for headdim padding cases
* dpad same as dvpad for flash attention integration
* unpadded lse&d for group mode
* Support unpad layout for group lse
* Support unpad lse layout for splitkv
* Fix stride for splitkv kernel
* fix unpadded lse issue in fwd splitkv
* comment
* solve lds read&write conflicts
* rename
* bias rename
* tile index revert
---------
Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
This commit is contained in:
@@ -53,6 +53,39 @@ class philox
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
out_tmp[1] = tmp[start_idx + 2];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user