From c7de3af246e025f0c32aba113cc520f7d21107c8 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 6 Jun 2026 06:41:25 +0000 Subject: [PATCH] Split hstu_attention_util.hpp into host_util.hpp and kernel_util.hpp --- .../example_hstu_attention_fwd.cpp | 2 +- .../hstu_attention_fwd_kernel.hpp | 3 +- .../hstu_attention_fwd_setting.hpp | 2 +- .../hstu_attention_fwd_splitkv_kernel.hpp | 3 +- .../hstu_attention_host_util.hpp | 32 +++++++++++++++++++ ...til.hpp => hstu_attention_kernel_util.hpp} | 26 --------------- ...tention_no_softmax_fwd_trload_pipeline.hpp | 2 +- .../hstu_attention_splitkv_helper.hpp | 4 ++- 8 files changed, 40 insertions(+), 34 deletions(-) create mode 100644 example/ck_tile/18_hstu_attention/hstu_attention_host_util.hpp rename example/ck_tile/18_hstu_attention/{hstu_attention_util.hpp => hstu_attention_kernel_util.hpp} (85%) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp index 901993f805..7e1cd9e643 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp @@ -27,7 +27,7 @@ #include "hstu_attention_params.hpp" #include "reference_hstu_attention_fwd.hpp" -#include "hstu_attention_util.hpp" +#include "hstu_attention_host_util.hpp" #include "hstu_attention_api.hpp" template diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index d0ecab36bd..1ea3309163 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include #include #include @@ -13,7 +12,7 @@ #include #include "hstu_block_masking.hpp" -#include "hstu_attention_util.hpp" +#include "hstu_attention_kernel_util.hpp" #ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM #define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1 diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 98c912fc5f..ce6dc897e1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -7,7 +7,7 @@ #include "hstu_attention_fwd_type_config.hpp" #include "hstu_attention_tile_setting_define.hpp" -#include "hstu_attention_util.hpp" +#include "hstu_attention_host_util.hpp" using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>; using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp index 75d04f1262..a91b667619 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include #include #include @@ -13,7 +12,7 @@ #include #include "hstu_block_masking.hpp" -#include "hstu_attention_util.hpp" +#include "hstu_attention_kernel_util.hpp" #ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM #define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1 diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_host_util.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_host_util.hpp new file mode 100644 index 0000000000..26d266d7f0 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_host_util.hpp @@ -0,0 +1,32 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include + +#define HSTU_CHECK(COND, ERR) \ + if(!(COND)) \ + { \ + std::ostringstream ostr; \ + ostr << "'" #COND "' failed: " << ERR; \ + throw std::runtime_error(ostr.str()); \ + } + +static inline int get_number_of_cu() +{ + int device; + + HIP_CHECK_ERROR(hipGetDevice(&device)); + + hipDeviceProp_t props; + + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device)); + + return props.multiProcessorCount; +} diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_kernel_util.hpp similarity index 85% rename from example/ck_tile/18_hstu_attention/hstu_attention_util.hpp rename to example/ck_tile/18_hstu_attention/hstu_attention_kernel_util.hpp index a891c1bc88..45da8261a3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_kernel_util.hpp @@ -4,33 +4,7 @@ #pragma once -#include -#include -#include - #include -#include - -#define HSTU_CHECK(COND, ERR) \ - if(!(COND)) \ - { \ - std::ostringstream ostr; \ - ostr << "'" #COND "' failed: " << ERR; \ - throw std::runtime_error(ostr.str()); \ - } - -static inline int get_number_of_cu() -{ - int device; - - HIP_CHECK_ERROR(hipGetDevice(&device)); - - hipDeviceProp_t props; - - HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device)); - - return props.multiProcessorCount; -} namespace ck_tile { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 14e7897c88..9a6949d382 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -7,7 +7,7 @@ #include #include "hstu_attention_fwd_pipeline_policy.hpp" -#include "hstu_attention_util.hpp" +#include "hstu_attention_kernel_util.hpp" namespace ck_tile { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp index 728dcb3095..8fec82f739 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp @@ -3,7 +3,9 @@ #pragma once -#include "hstu_attention_util.hpp" +#include + +#include "hstu_attention_host_util.hpp" static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q) {