mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
flashattention fwd add (80, 96) instance (#3415)
* add hdim (96,96) instance * change to (80,96) * format py * remove 96 in optdim * when N=6 change to llvm_amdgcn_raw_buffer_load_i32x3
This commit is contained in:
@@ -1121,6 +1121,20 @@ llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
|
||||
|
||||
// dwordx3 - use union to convert between int32x3 and fp16/bf16 types
|
||||
union dwordx3_union
|
||||
{
|
||||
int32_t as_i32[3];
|
||||
fp16_t as_fp16[6];
|
||||
bf16_t as_bf16[6];
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int32x3_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x3(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v3i32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
@@ -1540,9 +1554,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, bf16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, int32_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
@@ -1659,6 +1673,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence)));
|
||||
}
|
||||
else if constexpr(N == 6)
|
||||
{
|
||||
// N = 6: load as dwordx3 (12 bytes = 6 fp16), using buffer_load_dwordx3 instruction
|
||||
int32x3_t tmp_i32x3 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
// Use union to reinterpret int32x3 as fp16x6
|
||||
dwordx3_union tmp_union;
|
||||
tmp_union.as_i32[0] = tmp_i32x3[0];
|
||||
tmp_union.as_i32[1] = tmp_i32x3[1];
|
||||
tmp_union.as_i32[2] = tmp_i32x3[2];
|
||||
|
||||
thread_buffer<fp16_t, N> result;
|
||||
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_fp16[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
// use fp32 load to mimic fp16 load
|
||||
@@ -1744,6 +1778,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence)));
|
||||
}
|
||||
else if constexpr(N == 6)
|
||||
{
|
||||
// N = 6: load as dwordx3 (12 bytes = 6 bf16), using buffer_load_dwordx3 instruction
|
||||
int32x3_t tmp_i32x3 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
// Use union to reinterpret int32x3 as bf16x6
|
||||
dwordx3_union tmp_union;
|
||||
tmp_union.as_i32[0] = tmp_i32x3[0];
|
||||
tmp_union.as_i32[1] = tmp_i32x3[1];
|
||||
tmp_union.as_i32[2] = tmp_i32x3[2];
|
||||
|
||||
thread_buffer<bf16_t, N> result;
|
||||
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_bf16[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
|
||||
@@ -989,6 +989,20 @@ llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
|
||||
|
||||
// dwordx3 - use union to convert between int32x3 and fp16/bf16 types
|
||||
union dwordx3_union
|
||||
{
|
||||
int32_t as_i32[3];
|
||||
fp16_t as_fp16[6];
|
||||
bf16_t as_bf16[6];
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int32x3_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x3(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v3i32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN int32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
@@ -1408,9 +1422,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, bf16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, int32_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
@@ -1529,6 +1543,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence)));
|
||||
}
|
||||
else if constexpr(N == 6)
|
||||
{
|
||||
// N = 6: load as dwordx3 (12 bytes = 6 fp16), using buffer_load_dwordx3 instruction
|
||||
int32x3_t tmp_i32x3 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
// Use union to reinterpret int32x3 as fp16x6
|
||||
dwordx3_union tmp_union;
|
||||
tmp_union.as_i32[0] = tmp_i32x3[0];
|
||||
tmp_union.as_i32[1] = tmp_i32x3[1];
|
||||
tmp_union.as_i32[2] = tmp_i32x3[2];
|
||||
|
||||
thread_buffer<fp16_t, N> result;
|
||||
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_fp16[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
else
|
||||
{
|
||||
// N >= 8: build from fp32x4 chunks
|
||||
@@ -1571,6 +1605,26 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence)));
|
||||
}
|
||||
else if constexpr(N == 6)
|
||||
{
|
||||
// N = 6: load as dwordx3 (12 bytes = 6 bf16), using buffer_load_dwordx3 instruction
|
||||
int32x3_t tmp_i32x3 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
// Use union to reinterpret int32x3 as bf16x6
|
||||
dwordx3_union tmp_union;
|
||||
tmp_union.as_i32[0] = tmp_i32x3[0];
|
||||
tmp_union.as_i32[1] = tmp_i32x3[1];
|
||||
tmp_union.as_i32[2] = tmp_i32x3[2];
|
||||
|
||||
thread_buffer<bf16_t, N> result;
|
||||
static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_bf16[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
else
|
||||
{
|
||||
// N >= 8: build from fp32x4 chunks
|
||||
|
||||
@@ -152,6 +152,7 @@ using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
|
||||
using int32x3_t = int32_t __attribute__((ext_vector_type(3)));
|
||||
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
|
||||
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
|
||||
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
|
||||
|
||||
Reference in New Issue
Block a user