Split hstu_attention_util.hpp into host_util.hpp and kernel_util.hpp

This commit is contained in:
Qianfeng Zhang
2026-06-06 06:41:25 +00:00
parent 89d6f5aa92
commit c7de3af246
8 changed files with 40 additions and 34 deletions

View File

@@ -27,7 +27,7 @@
#include "hstu_attention_params.hpp"
#include "reference_hstu_attention_fwd.hpp"
#include "hstu_attention_util.hpp"
#include "hstu_attention_host_util.hpp"
#include "hstu_attention_api.hpp"
template <typename T>

View File

@@ -4,7 +4,6 @@
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/ops/common.hpp>
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
#include <string>
@@ -13,7 +12,7 @@
#include <variant>
#include "hstu_block_masking.hpp"
#include "hstu_attention_util.hpp"
#include "hstu_attention_kernel_util.hpp"
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1

View File

@@ -7,7 +7,7 @@
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_tile_setting_define.hpp"
#include "hstu_attention_util.hpp"
#include "hstu_attention_host_util.hpp"
using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>;
using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>;

View File

@@ -4,7 +4,6 @@
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/ops/common.hpp>
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
#include <string>
@@ -13,7 +12,7 @@
#include <variant>
#include "hstu_block_masking.hpp"
#include "hstu_attention_util.hpp"
#include "hstu_attention_kernel_util.hpp"
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <ck_tile/host/hip_check_error.hpp>
#define HSTU_CHECK(COND, ERR) \
if(!(COND)) \
{ \
std::ostringstream ostr; \
ostr << "'" #COND "' failed: " << ERR; \
throw std::runtime_error(ostr.str()); \
}
static inline int get_number_of_cu()
{
int device;
HIP_CHECK_ERROR(hipGetDevice(&device));
hipDeviceProp_t props;
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
return props.multiProcessorCount;
}

View File

@@ -4,33 +4,7 @@
#pragma once
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <ck_tile/core.hpp>
#include <ck_tile/host/hip_check_error.hpp>
#define HSTU_CHECK(COND, ERR) \
if(!(COND)) \
{ \
std::ostringstream ostr; \
ostr << "'" #COND "' failed: " << ERR; \
throw std::runtime_error(ostr.str()); \
}
static inline int get_number_of_cu()
{
int device;
HIP_CHECK_ERROR(hipGetDevice(&device));
hipDeviceProp_t props;
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
return props.multiProcessorCount;
}
namespace ck_tile {

View File

@@ -7,7 +7,7 @@
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
#include "hstu_attention_fwd_pipeline_policy.hpp"
#include "hstu_attention_util.hpp"
#include "hstu_attention_kernel_util.hpp"
namespace ck_tile {

View File

@@ -3,7 +3,9 @@
#pragma once
#include "hstu_attention_util.hpp"
#include <ck_tile/core/numeric/math.hpp>
#include "hstu_attention_host_util.hpp"
static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
{