mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Small refinements in C++ source files
This commit is contained in:
@@ -55,15 +55,16 @@ PIPELINE_MAP = {
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
ELEMENT_FUNC_MAP = {
|
||||
"no" : "FmhaDefaultElementFunctions",
|
||||
"f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions",
|
||||
}
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
}
|
||||
|
||||
ELEMENT_FUNC_MAP = {
|
||||
"no" : "FmhaDefaultElementFunctions",
|
||||
"f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t" : "true",
|
||||
"f" : "false"
|
||||
@@ -101,13 +102,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
|
||||
// typename {F_element_func}::QElementFunction,
|
||||
// typename {F_element_func}::KElementFunction,
|
||||
// typename {F_element_func}::VElementFunction,
|
||||
// typename {F_element_func}::BiasElementFunction,
|
||||
// typename {F_element_func}::LSEElementFunction,
|
||||
// typename {F_element_func}::SAccElementFunction,
|
||||
using fmha_element_function_{F_idx} = ck_tile::TileFmhaElementFunctions<
|
||||
typename {F_element_func}::PComputeElementFunction,
|
||||
typename {F_element_func}::OAccElementFunction>;
|
||||
|
||||
|
||||
@@ -49,10 +49,10 @@
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/unary_element_function.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/unary_element_function.hpp"
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_element_functions.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -9,10 +9,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// deprecated pipeline
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVSFp8
|
||||
struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace ck_tile {
|
||||
|
||||
/// TODO: support specifying more elementwise functions for input/output tensors
|
||||
template <typename PComputeElementFunction_, typename OAccElementFunction_>
|
||||
struct FmhaElementFunctions
|
||||
struct TileFmhaElementFunctions
|
||||
{
|
||||
using PComputeElementFunction = remove_cvref_t<PComputeElementFunction_>;
|
||||
using OAccElementFunction = remove_cvref_t<OAccElementFunction_>;
|
||||
Reference in New Issue
Block a user