mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Re-enable optimization for gfx950 fmha fwd (#2671)
* Fix for fwd/bwd kernel build filter
* fix bwd code
* save an example for __bf16 type
* temp save, waiting for debug
* tempsave, fmha_decode
* temp save, change all instance to 1wave
* fix async copytest bug
* Add block_sync_lds_direct_load utility
* fix the s_waitcnt_imm calculation
* Improve s_waitcnt_imm calculation
* fix vmcnt shift
* add input validation and bug fix
* remove unnecessary output
* move test_copy into test
* temp save
* tempsave
* compile pass
* tempsave, trload+asyncload done
* tempsave. asynccopy+trload sanity checked
* remove unnecessary features
* fix the lds alignment caused performance regression
* enable prefill overload operator().
* remove all lds bankconflict with xor layouts
* enable larger tile size; upgrade xor pattern
* upgrade prefill pipeline; simple iglp; consistent data produce and consume order
* small refactor
* Load Q through lds, implement xor;
* add vmcnt guard before load ktile
* Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA
* Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug
* add __restrict__ to tr load
* merge fa_decode pipeline into fmha_fwd api
* remove unnecessary files; rename some files
* Remove unnecessary changes
* bug fix, clang format;
* remove non-necessary change
* fix clangformat with 18.1.3
* fix bugs
* fix bug
* fix bug on non-gfx950
* fix bugs in gemm
* fix bug in pki4
* tempsave, update the blocksync functions
* change the warp setting for hdim32 fmha fwd
* clang format
* fix conflict. disable all v-col instance for fmha fwd
* Fix the bug
* clang format
* refactor blockgemm change, isolate to v2;
---------
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
[ROCm/composable_kernel commit: 05a6e92705]
This commit is contained in:
@@ -6,6 +6,9 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
#else
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
using bfloat16_t = __bf16;
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
#endif
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = uint16_t;
|
||||
#endif
|
||||
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace ck_tile {
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ struct numeric_traits<pk_int4_t>
|
||||
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
{
|
||||
|
||||
@@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf16
|
||||
// using bf16_t = ...
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
|
||||
Reference in New Issue
Block a user