flashattention fwd add (80, 96) instance (#3415)

* add hdim (96,96) instance

* change to (80,96)

* format py

* remove 96 in optdim

* when N=6 change to llvm_amdgcn_raw_buffer_load_i32x3
This commit is contained in:
ltqin
2025-12-18 01:16:11 +08:00
committed by GitHub
parent fe3d52d9b0
commit 92653168c2
6 changed files with 127 additions and 6 deletions

View File

@@ -1121,6 +1121,20 @@ llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
// dwordx3 - use union to convert between int32x3 and fp16/bf16 types.
// All three members overlay the same storage: the N == 6 buffer-load paths store the
// three loaded dwords through as_i32 and then read them back through as_fp16 or as_bf16.
// NOTE(review): reading a member other than the one last written relies on the
// compiler's union type-punning behaviour - confirm this holds for every supported toolchain.
union dwordx3_union
{
// the three raw 32-bit words as loaded
int32_t as_i32[3];
// same storage viewed as six fp16 values (assumes fp16_t is 2 bytes - TODO confirm)
fp16_t as_fp16[6];
// same storage viewed as six bf16 values (assumes bf16_t is 2 bytes - TODO confirm)
bf16_t as_bf16[6];
};
// Three-dword raw buffer load: declaration bound through __asm to the LLVM intrinsic
// "llvm.amdgcn.raw.buffer.load.v3i32"; the result is returned as an int32x3_t.
// The N == 6 load paths pass, in order: the wave buffer resource, the per-thread address
// offset, the wave address offset, and the coherence value cast to index_t.
CK_TILE_DEVICE_EXTERN int32x3_t
llvm_amdgcn_raw_buffer_load_i32x3(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v3i32");
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
@@ -1540,9 +1554,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, bf16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
@@ -1659,6 +1673,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 6)
{
// N = 6: load as dwordx3 (12 bytes = 6 fp16), using buffer_load_dwordx3 instruction
int32x3_t tmp_i32x3 =
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
// Use union to reinterpret int32x3 as fp16x6: the three dwords are stored one by one
// through as_i32, then the same storage is read back as six fp16 values.
// NOTE(review): presumably the element-wise copy is used because the 3-lane vector
// type may occupy more than 12 bytes - confirm.
dwordx3_union tmp_union;
tmp_union.as_i32[0] = tmp_i32x3[0];
tmp_union.as_i32[1] = tmp_i32x3[1];
tmp_union.as_i32[2] = tmp_i32x3[2];
thread_buffer<fp16_t, N> result;
// Copy each of the N (= 6) fp16 elements into the buffer that is returned.
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_fp16[i]; });
return result;
}
else if constexpr(N == 8)
{
// use fp32 load to mimic fp16 load
@@ -1744,6 +1778,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 6)
{
// N = 6: load as dwordx3 (12 bytes = 6 bf16), using buffer_load_dwordx3 instruction
int32x3_t tmp_i32x3 =
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
// Use union to reinterpret int32x3 as bf16x6: the three dwords are stored one by one
// through as_i32, then the same storage is read back as six bf16 values.
// NOTE(review): presumably the element-wise copy is used because the 3-lane vector
// type may occupy more than 12 bytes - confirm.
dwordx3_union tmp_union;
tmp_union.as_i32[0] = tmp_i32x3[0];
tmp_union.as_i32[1] = tmp_i32x3[1];
tmp_union.as_i32[2] = tmp_i32x3[2];
thread_buffer<bf16_t, N> result;
// Copy each of the N (= 6) bf16 elements into the buffer that is returned.
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_bf16[i]; });
return result;
}
else if constexpr(N == 8)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,

View File

@@ -989,6 +989,20 @@ llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
// dwordx3 - use union to convert between int32x3 and fp16/bf16 types.
// All three members overlay the same storage: the N == 6 buffer-load paths store the
// three loaded dwords through as_i32 and then read them back through as_fp16 or as_bf16.
// NOTE(review): reading a member other than the one last written relies on the
// compiler's union type-punning behaviour - confirm this holds for every supported toolchain.
union dwordx3_union
{
// the three raw 32-bit words as loaded
int32_t as_i32[3];
// same storage viewed as six fp16 values (assumes fp16_t is 2 bytes - TODO confirm)
fp16_t as_fp16[6];
// same storage viewed as six bf16 values (assumes bf16_t is 2 bytes - TODO confirm)
bf16_t as_bf16[6];
};
// Three-dword raw buffer load: declaration bound through __asm to the LLVM intrinsic
// "llvm.amdgcn.raw.buffer.load.v3i32"; the result is returned as an int32x3_t.
// The N == 6 load paths pass, in order: the wave buffer resource, the per-thread address
// offset, the wave address offset, and the coherence value cast to index_t.
CK_TILE_DEVICE_EXTERN int32x3_t
llvm_amdgcn_raw_buffer_load_i32x3(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v3i32");
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
@@ -1408,9 +1422,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, bf16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
@@ -1529,6 +1543,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 6)
{
// N = 6: load as dwordx3 (12 bytes = 6 fp16), using buffer_load_dwordx3 instruction
int32x3_t tmp_i32x3 =
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
// Use union to reinterpret int32x3 as fp16x6: the three dwords are stored one by one
// through as_i32, then the same storage is read back as six fp16 values.
// NOTE(review): presumably the element-wise copy is used because the 3-lane vector
// type may occupy more than 12 bytes - confirm.
dwordx3_union tmp_union;
tmp_union.as_i32[0] = tmp_i32x3[0];
tmp_union.as_i32[1] = tmp_i32x3[1];
tmp_union.as_i32[2] = tmp_i32x3[2];
thread_buffer<fp16_t, N> result;
// Copy each of the N (= 6) fp16 elements into the buffer that is returned.
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_fp16[i]; });
return result;
}
else
{
// N >= 8: build from fp32x4 chunks
@@ -1571,6 +1605,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 6)
{
// N = 6: load as dwordx3 (12 bytes = 6 bf16), using buffer_load_dwordx3 instruction
int32x3_t tmp_i32x3 =
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
// Use union to reinterpret int32x3 as bf16x6: the three dwords are stored one by one
// through as_i32, then the same storage is read back as six bf16 values.
// NOTE(review): presumably the element-wise copy is used because the 3-lane vector
// type may occupy more than 12 bytes - confirm.
dwordx3_union tmp_union;
tmp_union.as_i32[0] = tmp_i32x3[0];
tmp_union.as_i32[1] = tmp_i32x3[1];
tmp_union.as_i32[2] = tmp_i32x3[2];
thread_buffer<bf16_t, N> result;
// Copy each of the N (= 6) bf16 elements into the buffer that is returned.
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_bf16[i]; });
return result;
}
else
{
// N >= 8: build from fp32x4 chunks

View File

@@ -152,6 +152,7 @@ using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...
// Fixed-width int32_t vector aliases built with the ext_vector_type attribute;
// the number in each name is the lane count.
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
// 3-lane variant: the return type of llvm_amdgcn_raw_buffer_load_i32x3.
using int32x3_t = int32_t __attribute__((ext_vector_type(3)));
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));