mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
save an example for __bf16 type
This commit is contained in:
@@ -17,11 +17,14 @@ if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
endif()
|
||||
|
||||
# Filtering kernel
|
||||
set(KERNEL fmha_fwd_d64_bf16_batch_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
|
||||
|
||||
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --filter ${KERNEL}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -45,7 +48,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --filter ${KERNEL}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
@@ -86,6 +89,7 @@ if(FMHA_FWD_FAST_EXP2)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
|
||||
|
||||
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#if defined(__gfx950__)
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
#else
|
||||
#if 1 // ROCm 7.0
|
||||
using bfloat16_t = __bf16;
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
#endif
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = uint16_t;
|
||||
#endif
|
||||
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
#if defined (__gfx950__)
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
|
||||
@@ -99,7 +99,7 @@ struct numeric_traits<pk_int4_t>
|
||||
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
{
|
||||
@@ -140,9 +140,9 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
|
||||
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
|
||||
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
|
||||
bf16x2_t res = {static_cast<bf16_t>(x_h), static_cast<bf16_t>(x_l)};
|
||||
#elif
|
||||
bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
|
||||
bf16x2_t res = {static_cast<bf16_t>(x_l), static_cast<bf16_t>(x_h)};
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf16
|
||||
// using bf16_t = ...
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
|
||||
Reference in New Issue
Block a user