Optimize fmha fwd decode & prefill for gfx950 (#2641)

* Fix for fwd/bwd kernel build filter

* fix bwd code

* save an example for __bf16 type

* temp save, waiting for debug

* tempsave, fmha_decode

* temp save, change all instance to 1wave

* fix async copytest bug

* Add block_sync_lds_direct_load utility

* fix the s_waitcnt_imm calculation

* Improve s_waitcnt_imm calculation

* fix vmcnt shift

* add input validation and bug fix

* remove unnecessary output

* move test_copy into test

* temp save

* tempsave

* compile pass

* tempsave, trload+asyncload done

* tempsave. asynccopy+trload sanity checked

* remove unnecessary features

* fix the lds alignment caused performance regression

* enable prefill overload operator().

* remove all lds bankconflict with xor layouts

* enable larger tile size; upgrade xor pattern

* upgrade prefill pipeline; simple iglp; consistent data produce and consume order

* small refactor

* Load Q through lds, implement xor;

* add vmcnt guard before load ktile

* Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA

* Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug

* add __restrict__ to tr load

* merge fa_decode pipeline into fmha_fwd api

* remove unnecessary files; rename some files

* Remove unnecessary changes

* bug fix, clang format;

* remove non-necessary change

* fix clangformat with 18.1.3

* fix bugs

* fix bug

* fix bug on non-gfx950

* fix bugs in gemm

* fix bug in pki4

* tempsave, update the blocksync functions

* change the warp setting for hdim32 fmha fwd

* clang format

* fix conflict. disable all v-col instance for fmha fwd

* Fix the bug

* clang format

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>

[ROCm/composable_kernel commit: b7322a521a]
This commit is contained in:
Haocong WANG
2025-08-12 19:43:14 +08:00
committed by GitHub
parent eab0dc96f6
commit 8a20b06f54
31 changed files with 3533 additions and 627 deletions

View File

@@ -41,10 +41,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}
@@ -1318,6 +1314,17 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
template <index_t N,
@@ -2756,7 +2763,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),

View File

@@ -32,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}
@@ -1186,6 +1182,17 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
template <index_t N,
@@ -2574,7 +2581,7 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),

View File

@@ -89,21 +89,6 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier();
#else
__syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
@@ -174,6 +159,18 @@ CK_TILE_DEVICE void s_waitcnt_barrier()
__builtin_amdgcn_s_barrier();
}
template <index_t lgkmcnt = 0>
CK_TILE_DEVICE void block_sync_lds()
{
s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
}
template <index_t vmcnt = 0>
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
}
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1

View File

@@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template <typename T>
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
{
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const int32x2_t x = __builtin_amdgcn_permlane32_swap(
bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
thread_buffer<T, 2> v;
v(0) = bit_cast<T>(x[0]);
v(1) = bit_cast<T>(x[1]);
return v;
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{

View File

@@ -191,6 +191,16 @@
#endif
#endif
// use llvm builtin bf16 data type after ROCm 6.5
#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
(HIP_VERSION_MAJOR >= 7)
#define CK_TILE_USE_LLVM_BUILTIN_BF16 1
#else
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif

View File

@@ -6,6 +6,9 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#if CK_TILE_USE_LLVM_BUILTIN_BF16
#include <hip/hip_bfloat16.h>
#endif
#include <stdint.h>
#pragma once
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
#if CK_TILE_USE_LLVM_BUILTIN_BF16
using bfloat16_t = __bf16;
#else
using bfloat16_t = ushort;
#endif
using bf16_t = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined(__gfx950__)
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
#endif
}
template <bf16_rounding_mode rounding =

View File

@@ -21,7 +21,7 @@ namespace ck_tile {
using fp32_t = float;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);

View File

@@ -99,7 +99,7 @@ struct numeric_traits<pk_int4_t>
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{

View File

@@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...