diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md index 2dfe038d5f..10e692df00 100644 --- a/example/ck_tile/18_hstu_attention/README.md +++ b/example/ck_tile/18_hstu_attention/README.md @@ -13,8 +13,8 @@ ``` bash #> mkdir build #> cd build - #> ../script/cmake-ck-dev.sh .. gfx942 ; use #> rocminfo |grep "gfx" to check your gpu arch - #> make -j tile_example_hstu_attention + #> ../script/cmake-ck-dev.sh .. gfx942 -G Ninja ; use #> rocminfo |grep "gfx" to check your gpu arch + #> ninja tile_example_hstu_attention ; or using make -j tile_example_hstu_attention ; ``` ## test/verify diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index da579605c1..def95129ab 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -27,11 +27,7 @@ #include "reference_hstu_attention.hpp" #include "hstu_attention_util.hpp" - -extern void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream); -extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream); -extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream); -extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream); +#include "hstu_attention_api.hpp" template void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_api.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_api.hpp new file mode 100644 index 0000000000..b50815b8fb --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_api.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "hstu_attention_params.hpp" + +extern void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream); +extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream); +extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream); +extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);