mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Split hstu_attention_util.hpp into hstu_attention_host_util.hpp and hstu_attention_kernel_util.hpp
This commit is contained in:
@@ -27,7 +27,7 @@
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "reference_hstu_attention_fwd.hpp"
|
||||
|
||||
#include "hstu_attention_util.hpp"
|
||||
#include "hstu_attention_host_util.hpp"
|
||||
#include "hstu_attention_api.hpp"
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/common.hpp>
|
||||
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
|
||||
|
||||
#include <string>
|
||||
@@ -13,7 +12,7 @@
|
||||
#include <variant>
|
||||
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_util.hpp"
|
||||
#include "hstu_attention_kernel_util.hpp"
|
||||
|
||||
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
#include "hstu_attention_tile_setting_define.hpp"
|
||||
#include "hstu_attention_util.hpp"
|
||||
#include "hstu_attention_host_util.hpp"
|
||||
|
||||
using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>;
|
||||
using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>;
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/common.hpp>
|
||||
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
|
||||
|
||||
#include <string>
|
||||
@@ -13,7 +12,7 @@
|
||||
#include <variant>
|
||||
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_util.hpp"
|
||||
#include "hstu_attention_kernel_util.hpp"
|
||||
|
||||
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <ck_tile/host/hip_check_error.hpp>
|
||||
|
||||
// Runtime check: when COND evaluates to false, throws std::runtime_error whose
// message is "'<COND as written>' failed: " followed by ERR. ERR is streamed
// into a std::ostringstream, so it may be a chain such as "x = " << x.
// COND is evaluated exactly once; ERR is evaluated only when the check fails.
// The body is wrapped in do { ... } while(0) so the macro expands to a single
// statement and stays correct inside an unbraced if/else; invoke it with a
// trailing semicolon.
#define HSTU_CHECK(COND, ERR)                      \
    do                                             \
    {                                              \
        if(!(COND))                                \
        {                                          \
            std::ostringstream ostr;               \
            ostr << "'" #COND "' failed: " << ERR; \
            throw std::runtime_error(ostr.str());  \
        }                                          \
    } while(0)
|
||||
|
||||
static inline int get_number_of_cu()
|
||||
{
|
||||
int device;
|
||||
|
||||
HIP_CHECK_ERROR(hipGetDevice(&device));
|
||||
|
||||
hipDeviceProp_t props;
|
||||
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
|
||||
|
||||
return props.multiProcessorCount;
|
||||
}
|
||||
@@ -4,33 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/hip_check_error.hpp>
|
||||
|
||||
// Runtime check: when COND evaluates to false, throws std::runtime_error whose
// message is "'<COND as written>' failed: " followed by ERR. ERR is streamed
// into a std::ostringstream, so it may be a chain such as "x = " << x.
// COND is evaluated exactly once; ERR is evaluated only when the check fails.
// The body is wrapped in do { ... } while(0) so the macro expands to a single
// statement and stays correct inside an unbraced if/else; invoke it with a
// trailing semicolon.
#define HSTU_CHECK(COND, ERR)                      \
    do                                             \
    {                                              \
        if(!(COND))                                \
        {                                          \
            std::ostringstream ostr;               \
            ostr << "'" #COND "' failed: " << ERR; \
            throw std::runtime_error(ostr.str());  \
        }                                          \
    } while(0)
|
||||
|
||||
static inline int get_number_of_cu()
|
||||
{
|
||||
int device;
|
||||
|
||||
HIP_CHECK_ERROR(hipGetDevice(&device));
|
||||
|
||||
hipDeviceProp_t props;
|
||||
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
|
||||
|
||||
return props.multiProcessorCount;
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
|
||||
|
||||
#include "hstu_attention_fwd_pipeline_policy.hpp"
|
||||
#include "hstu_attention_util.hpp"
|
||||
#include "hstu_attention_kernel_util.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "hstu_attention_util.hpp"
|
||||
#include <ck_tile/core/numeric/math.hpp>
|
||||
|
||||
#include "hstu_attention_host_util.hpp"
|
||||
|
||||
static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user