diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index ec84bb2a06..bd03aee924 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -13,9 +13,9 @@ foreach(api ${FMHA_FWD_ENABLE_APIS}) endforeach() # "fwd" is a must-have api for the fmha_fwd example, add it if not specified -# if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) -# list(APPEND FMHA_FWD_ENABLE_APIS "fwd") -# endif() +if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_ENABLE_APIS "fwd") +endif() file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS ${CMAKE_CURRENT_LIST_DIR}/generate.py @@ -109,7 +109,6 @@ if(FMHA_FWD_FAST_EXP2) else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() -# list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) # conditionally enable call to the fwd_splitkv API in fmha_fwd example diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d873388876..5b7a71538d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -676,51 +676,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } - else if(init_method == "v1" || init_method == "97") - { - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(vnew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - } - else if(init_method == "k1" || init_method == "96") - { - 
ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(k_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(knew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - } - else if(init_method == "q1" || init_method == "95") - { - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(q_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - } - else if(init_method == "kv1" || init_method == "98") - { - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(k_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(knew_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(vnew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - } - else if(init_method == "qkv1" || init_method == "99") - { - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(q_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(k_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(knew_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{1.f, 1.f, seed}(vnew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - } else if(init_method == "nf") { ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); @@ -1144,7 +1099,8 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); } -#elif CK_TILE_FMHA_FWD_PAGEDKV_API +#endif +#if CK_TILE_FMHA_FWD_PAGEDKV_API if(use_kvcache) 
{ fmha_fwd_pagedkv_traits fmha_pagedkv_traits; @@ -1155,7 +1111,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); } -#else +#endif fmha_fwd_traits fmha_traits; init_traits(fmha_traits); @@ -1163,7 +1119,6 @@ bool run(const ck_tile::ArgParser& arg_parser) init_args(fmha_args); return fmha_fwd(fmha_traits, fmha_args, stream_config); -#endif }(); if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 07a2a6a666..07be65a150 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -41,10 +41,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - // r.x = __builtin_amdgcn_readfirstlane(r.x); - // r.y = __builtin_amdgcn_readfirstlane(r.y); - // r.z = __builtin_amdgcn_readfirstlane(r.z); - // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 740b2100ef..c64b296408 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -32,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - // r.x = __builtin_amdgcn_readfirstlane(r.x); - // r.y = __builtin_amdgcn_readfirstlane(r.y); - // r.z = __builtin_amdgcn_readfirstlane(r.z); - // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 
53ac69570d..4213397cdf 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -82,8 +82,7 @@ CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); } CK_TILE_DEVICE index_t get_warp_id() { - // return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); - return threadIdx.x / get_warp_size(); + return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); } CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index c471f416c3..555a047e40 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -191,6 +191,16 @@ #endif #endif +// use llvm builtin bf16 data type after ROCm 6.5 +#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16 +#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \ + (HIP_VERSION_MAJOR >= 7) +#define CK_TILE_USE_LLVM_BUILTIN_BF16 1 +#else +#define CK_TILE_USE_LLVM_BUILTIN_BF16 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 86bb1740a9..245fb7244f 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -6,7 +6,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/numeric.hpp" -#if defined(__gfx950__) +#if CK_TILE_USE_LLVM_BUILTIN_BF16 #include #endif #include @@ -105,8 +105,7 @@ struct native_t using bf16_t = bfloat16_t; using bf16_raw_t = typename bf16_t::raw_type; #else -#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \ - (HIP_VERSION_MAJOR >= 7) +#if CK_TILE_USE_LLVM_BUILTIN_BF16 using bfloat16_t = __bf16; #else using bfloat16_t = ushort; diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 
b359fe68ba..1535250722 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -14,7 +14,6 @@ #include "ck_tile/core/container/statically_indexed_array.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/core/utility/debug.hpp" namespace ck_tile { @@ -131,7 +130,6 @@ struct DefaultTranspose util::is_sequence_suffix_v; static constexpr bool suffix_valid_dim1 = util::is_sequence_suffix_v; - // using bbb = decltype(CK_PRINT()); // 3. PS→RHS mapping constraints static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_; @@ -171,8 +169,6 @@ struct DefaultTranspose static constexpr bool ys_mapping_valid = (input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1); - // using aaa = decltype(CK_PRINT()); static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 && ps_mapping_valid && ys_mapping_valid; }; diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 11dd234b83..ad5902f16e 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -288,8 +288,7 @@ struct tile_window_with_static_distribution sizeof(LdsDataType) - size_per_buf; - const index_t m0_init_value = - __builtin_amdgcn_readfirstlane(size_per_buf + size_per_wave * get_warp_id()); + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); m0_set_with_memory(m0_init_value); // This should be wave independent using Traits = typename Base::Traits; @@ -434,8 +433,6 @@ struct tile_window_with_static_distribution // data index [y0, y1, ...] 
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - // printf("Tid: %03d, tr_load_idx: %d\n", - // get_thread_local_1d_id(),bottom_tensor_thread_coord.get_offset()); // read from bottom tensor const vector_t vec_value = this->get_bottom_tensor_view() diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index bbb3e26588..b5a89e5f51 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -516,8 +516,7 @@ struct tile_window_linear sizeof(LdsDataType) - size_per_buf; - const index_t m0_init_value = - __builtin_amdgcn_readfirstlane(size_per_buf + size_per_wave * get_warp_id()); + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); m0_set_with_memory(m0_init_value); // This should be wave independent using vector_t = typename Base::Traits::vector_t; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index a419ea8ed3..501aa26667 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -1071,20 +1071,20 @@ struct FmhaFwdSplitKVKernel { return FmhaPipeline{}(q_dram_window, k_dram_window_lengths, - // k_page_block_navigator, + k_page_block_navigator, v_dram_window_lengths, - // v_page_block_navigator, + v_page_block_navigator, bias_dram_window, lse_acc_dram_window, - // kargs.num_splits, - // i_split_, + kargs.num_splits, + i_split_, mask, position_encoding, kargs.scale_s, variant, variant_params, block_indices, - // kv_l2p_offset, + kv_l2p_offset, smem_ptr); } }(); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index d583bf84d1..8c6f39e511 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ 
-45,7 +45,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; -// template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity @@ -76,7 +76,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; -// template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8