diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9c81207361..fbd6551091 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -47,7 +47,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} - --optdim 32,64,128,256 + --optdim 32,64,80,128,256 # --filter fmha_fwd... ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d157a165fc..0cffb2642c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -40,7 +40,16 @@ DTYPE_BITS = { "bf8": 8, } -K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} +K0_MAX_SUBMAX_MAP = { + 32: 32, + 48: 48, + 64: 64, + 80: 96, + 96: 128, + 128: 128, + 192: 192, + 256: 256, +} FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n @@ -930,6 +939,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 9c2ce62856..9f79bdbee6 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1121,6 +1121,20 @@ llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); +// dwordx3 - use union to convert between int32x3 and fp16/bf16 types +union dwordx3_union +{ + int32_t as_i32[3]; + fp16_t as_fp16[6]; + bf16_t as_bf16[6]; +}; + +CK_TILE_DEVICE_EXTERN int32x3_t +llvm_amdgcn_raw_buffer_load_i32x3(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v3i32"); + CK_TILE_DEVICE_EXTERN int32x4_t llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, @@ -1540,9 +1554,9 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || @@ -1659,6 +1673,26 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } + else if constexpr(N == 6) + { + // N = 6: load as dwordx3 (12 bytes = 6 fp16), using buffer_load_dwordx3 instruction + int32x3_t tmp_i32x3 = + llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + // Use union to reinterpret int32x3 as fp16x6 + dwordx3_union tmp_union; + tmp_union.as_i32[0] = tmp_i32x3[0]; + tmp_union.as_i32[1] = tmp_i32x3[1]; + tmp_union.as_i32[2] = tmp_i32x3[2]; + + thread_buffer result; + static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_fp16[i]; }); + + return result; + } else if constexpr(N == 8) { // use fp32 load to mimic fp16 load @@ -1744,6 +1778,26 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } + else if constexpr(N == 6) + { + // N = 6: load as dwordx3 (12 bytes = 6 bf16), using buffer_load_dwordx3 instruction + int32x3_t tmp_i32x3 = + llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + // Use union to reinterpret int32x3 as bf16x6 + dwordx3_union tmp_union; + tmp_union.as_i32[0] = tmp_i32x3[0]; + tmp_union.as_i32[1] = tmp_i32x3[1]; + tmp_union.as_i32[2] = tmp_i32x3[2]; + + thread_buffer result; + static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_bf16[i]; }); + + return result; + } else if constexpr(N == 8) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 764df83539..4627b249d6 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -989,6 +989,20 @@ llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); +// dwordx3 - use union to convert between int32x3 and fp16/bf16 types +union dwordx3_union +{ + int32_t as_i32[3]; + fp16_t as_fp16[6]; + bf16_t as_bf16[6]; +}; + +CK_TILE_DEVICE_EXTERN int32x3_t +llvm_amdgcn_raw_buffer_load_i32x3(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v3i32"); + CK_TILE_DEVICE_EXTERN int32x4_t llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, @@ -1408,9 +1422,9 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || @@ -1529,6 +1543,26 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } + else if constexpr(N == 6) + { + // N = 6: load as dwordx3 (12 bytes = 6 fp16), using buffer_load_dwordx3 instruction + int32x3_t tmp_i32x3 = + llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + // Use union to reinterpret int32x3 as fp16x6 + dwordx3_union tmp_union; + tmp_union.as_i32[0] = tmp_i32x3[0]; + tmp_union.as_i32[1] = tmp_i32x3[1]; + tmp_union.as_i32[2] = tmp_i32x3[2]; + + thread_buffer result; + static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_fp16[i]; }); + + return result; + } else { // N >= 8: build from fp32x4 chunks @@ -1571,6 +1605,26 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } + else if constexpr(N == 6) + { + // N = 6: load as dwordx3 (12 bytes = 6 bf16), using buffer_load_dwordx3 instruction + int32x3_t tmp_i32x3 = + llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + // Use union to reinterpret int32x3 as bf16x6 + dwordx3_union tmp_union; + tmp_union.as_i32[0] = tmp_i32x3[0]; + tmp_union.as_i32[1] = tmp_i32x3[1]; + tmp_union.as_i32[2] = tmp_i32x3[2]; + + thread_buffer result; + static_for<0, N, 1>{}([&](auto i) { result[i] = tmp_union.as_bf16[i]; }); + + return result; + } else { // N >= 8: build from fp32x4 chunks diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 6921210b34..90ddc2a56e 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -152,6 +152,7 @@ using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64))); // i32 // using int32_t = ... using int32x2_t = int32_t __attribute__((ext_vector_type(2))); +using int32x3_t = int32_t __attribute__((ext_vector_type(3))); using int32x4_t = int32_t __attribute__((ext_vector_type(4))); using int32x8_t = int32_t __attribute__((ext_vector_type(8))); using int32x16_t = int32_t __attribute__((ext_vector_type(16))); diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index ee5238869f..4045e31b17 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -12,6 +12,8 @@ static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length() { if constexpr(Headdim == 48) return 48; + else if constexpr(Headdim == 80) + return 96; else if constexpr(Headdim == 96) return 128; else if constexpr(Headdim == 160)