save an example for __bf16 type

This commit is contained in:
aska-0096
2025-06-19 05:11:52 +00:00
parent 4bd5fd4a3c
commit e0a634ef97
4 changed files with 26 additions and 11 deletions

View File

@@ -17,11 +17,14 @@ if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
endif()
# Filtering kernel
set(KERNEL fmha_fwd_d64_bf16_batch_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --filter ${KERNEL}
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
@@ -45,7 +48,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --filter ${KERNEL}
)
add_custom_command(
@@ -86,6 +89,7 @@ if(FMHA_FWD_FAST_EXP2)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
# conditionally enable call to the fwd_splitkv API in fmha_fwd example

View File

@@ -6,6 +6,9 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#if defined(__gfx950__)
#include <hip/hip_bfloat16.h>
#endif
#include <stdint.h>
#pragma once
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
#if 1 // ROCm 7.0
using bfloat16_t = __bf16;
#else
using bfloat16_t = ushort;
#endif
using bf16_t = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined (__gfx950__)
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
#endif
}
template <bf16_rounding_mode rounding =

View File

@@ -99,7 +99,7 @@ struct numeric_traits<pk_int4_t>
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
@@ -140,9 +140,9 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
bf16x2_t res = {static_cast<bf16_t>(x_h), static_cast<bf16_t>(x_l)};
#elif
bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
bf16x2_t res = {static_cast<bf16_t>(x_l), static_cast<bf16_t>(x_h)};
#endif
return res;
}

View File

@@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...