mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add unified attention (42_unified_attention)
Squashed from aghamari/unified-attention-decode-opt branch. CK tile paged-KV attention kernel optimized for decode with 4-tier dispatch (tiny/small/medium/large), 16x16 MFMA, 2D decode grid, head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA with block_size=32. Made-with: Cursor
This commit is contained in:
@@ -36,7 +36,7 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 32, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
|
||||
@@ -1,24 +1,11 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_topk_softmax_decode
|
||||
topk_softmax_decode.cpp
|
||||
topk_softmax_decode_api.cpp
|
||||
topk_softmax_api.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/moe_sorting_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax_decode PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/)
|
||||
target_compile_definitions(tile_example_topk_softmax_decode PRIVATE
|
||||
CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID=1
|
||||
MOE_SORTING_FMOE_2D_BUF=1)
|
||||
target_compile_options(tile_example_topk_softmax_decode PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
|
||||
@@ -1,314 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <time.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "topk_softmax_decode_api.hpp"
|
||||
#include "topk_softmax_api.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
|
||||
// CPU reference: softmax -> topk -> moe_sorting
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool reference_fused(const ck_tile::HostTensor<InputType>& x_host,
|
||||
ck_tile::index_t topk,
|
||||
ck_tile::index_t num_experts,
|
||||
ck_tile::index_t unit_size,
|
||||
ck_tile::HostTensor<IndexType>& ref_sorted_ids,
|
||||
ck_tile::HostTensor<WeightType>& ref_sorted_weights,
|
||||
ck_tile::HostTensor<IndexType>& ref_sorted_expert_ids,
|
||||
ck_tile::index_t& ref_unit_cnt)
|
||||
{
|
||||
auto probs = ck_tile::reference_softmax<InputType, WeightType, WeightType>(x_host);
|
||||
|
||||
ck_tile::HostTensor<WeightType> topk_vals({1, topk});
|
||||
ck_tile::HostTensor<IndexType> topk_idxs({1, topk});
|
||||
ck_tile::reference_topk(probs, topk_vals, topk_idxs, topk);
|
||||
|
||||
ck_tile::HostTensor<IndexType> local_expert_mask({num_experts});
|
||||
ref_unit_cnt = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(
|
||||
topk_idxs,
|
||||
topk_vals,
|
||||
local_expert_mask,
|
||||
ref_sorted_ids,
|
||||
ref_sorted_weights,
|
||||
ref_sorted_expert_ids,
|
||||
ref_unit_cnt,
|
||||
num_experts,
|
||||
unit_size,
|
||||
1,
|
||||
false,
|
||||
true);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "do CPU validation")
|
||||
.insert("pr_i", "bf16", "input data type: fp16/bf16")
|
||||
.insert("e", "128", "number of experts")
|
||||
.insert("k", "8", "topk")
|
||||
.insert("unit", "32", "unit_size (block_size_M)")
|
||||
.insert("model_dim", "7168", "model dimension for moe_buf zeroing")
|
||||
.insert("seed", "-1", "random seed, -1 = random")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool run_test(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string pr = args.get_str("pr_i");
|
||||
int experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int unit_size = args.get_int("unit");
|
||||
int model_dim = args.get_int("model_dim");
|
||||
int seed = args.get_int("seed");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
|
||||
if(seed < 0)
|
||||
seed = std::time(nullptr);
|
||||
|
||||
if(topk > experts)
|
||||
{
|
||||
printf("topk %d > experts %d, skip\n", topk, experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
int tokens = 1;
|
||||
int max_num_tokens_padded = topk + experts * unit_size - topk;
|
||||
int max_num_m_blocks = (max_num_tokens_padded + unit_size - 1) / unit_size;
|
||||
|
||||
// Host tensors
|
||||
ck_tile::HostTensor<InputType> x_host({1, experts});
|
||||
{
|
||||
auto rng = ck_tile::FillUniformDistribution_Unique<InputType>{
|
||||
-5.f, 5.f, static_cast<uint32_t>(seed)};
|
||||
ck_tile::HostTensor<InputType> row({experts});
|
||||
rng(row);
|
||||
std::copy(row.begin(), row.end(), x_host.begin());
|
||||
}
|
||||
|
||||
// ---------- Device buffers (shared) ----------
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
// ---------- Fused kernel buffers ----------
|
||||
ck_tile::DeviceMem fused_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
|
||||
ck_tile::DeviceMem fused_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_num_valid(2 * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_moe_buf(model_dim * sizeof(WeightType));
|
||||
{
|
||||
std::vector<float> ones(model_dim, 1.0f);
|
||||
fused_moe_buf.ToDevice(ones.data());
|
||||
}
|
||||
|
||||
// ---------- Two-kernel baseline buffers ----------
|
||||
ck_tile::DeviceMem topk_w_dev(topk * sizeof(WeightType));
|
||||
ck_tile::DeviceMem topk_i_dev(topk * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
|
||||
ck_tile::DeviceMem base_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_num_valid(2 * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_moe_buf(model_dim * sizeof(WeightType));
|
||||
int ws_size = moe_sorting_get_workspace_size(tokens, experts, topk, 0);
|
||||
ck_tile::DeviceMem base_ws(ws_size > 0 ? ws_size : 1);
|
||||
if(ws_size > 0)
|
||||
base_ws.SetZero();
|
||||
|
||||
ck_tile::stream_config sc{nullptr, true, 0, warmup, repeat};
|
||||
|
||||
// ====================== Fused kernel ======================
|
||||
topk_softmax_decode_trait fused_trait{pr, "fp32", experts, "softmax"};
|
||||
topk_softmax_decode_kargs fused_karg{
|
||||
x_dev.GetDeviceBuffer(),
|
||||
experts,
|
||||
topk,
|
||||
experts,
|
||||
true,
|
||||
fused_sorted_ids.GetDeviceBuffer(),
|
||||
fused_sorted_weights.GetDeviceBuffer(),
|
||||
fused_sorted_expert_ids.GetDeviceBuffer(),
|
||||
fused_num_valid.GetDeviceBuffer(),
|
||||
fused_moe_buf.GetDeviceBuffer(),
|
||||
unit_size,
|
||||
model_dim,
|
||||
static_cast<int>(sizeof(WeightType))};
|
||||
|
||||
float ms_fused = topk_softmax_decode(fused_trait, fused_karg, sc);
|
||||
|
||||
// ============= Two-kernel baseline: topk + moe_sorting =============
|
||||
topk_softmax_trait ts_trait{pr, "fp32", experts, "softmax"};
|
||||
topk_softmax_kargs ts_karg{
|
||||
x_dev.GetDeviceBuffer(),
|
||||
topk_w_dev.GetDeviceBuffer(),
|
||||
topk_i_dev.GetDeviceBuffer(),
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
experts,
|
||||
topk};
|
||||
|
||||
moe_sorting_trait ms_trait{"int32", "fp32", false, true, 0};
|
||||
moe_sorting_args ms_arg{
|
||||
topk_i_dev.GetDeviceBuffer(),
|
||||
topk_w_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
base_sorted_ids.GetDeviceBuffer(),
|
||||
base_sorted_weights.GetDeviceBuffer(),
|
||||
base_sorted_expert_ids.GetDeviceBuffer(),
|
||||
base_num_valid.GetDeviceBuffer(),
|
||||
base_moe_buf.GetDeviceBuffer(),
|
||||
ws_size > 0 ? base_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
experts,
|
||||
topk,
|
||||
model_dim,
|
||||
static_cast<int>(sizeof(WeightType))};
|
||||
|
||||
// Time the two kernels together using launch_kernel with two lambdas
|
||||
auto sc_sub = ck_tile::stream_config{nullptr, false, 0, 0, 1};
|
||||
float ms_baseline = ck_tile::launch_kernel(
|
||||
sc,
|
||||
[&](const ck_tile::stream_config&) { topk_softmax(ts_trait, ts_karg, sc_sub); },
|
||||
[&](const ck_tile::stream_config&) { moe_sorting(ms_trait, ms_arg, sc_sub); });
|
||||
|
||||
float speedup = (ms_baseline > 0 && ms_fused > 0) ? ms_baseline / ms_fused : 0;
|
||||
printf("[%s] E:%d, k:%d, unit:%d | fused:%.4fms baseline(topk+sort):%.4fms speedup:%.2fx",
|
||||
pr.c_str(),
|
||||
experts,
|
||||
topk,
|
||||
unit_size,
|
||||
ms_fused,
|
||||
ms_baseline,
|
||||
speedup);
|
||||
|
||||
if(ms_fused < 0 || ms_baseline < 0)
|
||||
{
|
||||
printf(" (not supported)\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// ====================== Validation ======================
|
||||
bool pass = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_num_m_blocks});
|
||||
ck_tile::HostTensor<IndexType> num_valid_host({2});
|
||||
std::vector<float> moe_buf_host(model_dim);
|
||||
|
||||
fused_sorted_ids.FromDevice(sorted_ids_host.data());
|
||||
fused_sorted_weights.FromDevice(sorted_weights_host.data());
|
||||
fused_sorted_expert_ids.FromDevice(sorted_expert_ids_host.data());
|
||||
fused_num_valid.FromDevice(num_valid_host.data());
|
||||
fused_moe_buf.FromDevice(moe_buf_host.data());
|
||||
|
||||
ck_tile::HostTensor<IndexType> ref_sorted_ids({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<WeightType> ref_sorted_weights({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexType> ref_sorted_expert_ids({max_num_m_blocks});
|
||||
IndexType sentinel = static_cast<uint32_t>((1 & 0x00ffffff) | ((topk & 0xff) << 24));
|
||||
std::fill(ref_sorted_ids.begin(), ref_sorted_ids.end(), sentinel);
|
||||
std::fill(ref_sorted_weights.begin(), ref_sorted_weights.end(), WeightType(0));
|
||||
std::fill(ref_sorted_expert_ids.begin(), ref_sorted_expert_ids.end(), -1);
|
||||
|
||||
ck_tile::index_t ref_unit_cnt = 0;
|
||||
reference_fused<InputType, WeightType, IndexType>(
|
||||
x_host, topk, experts, unit_size,
|
||||
ref_sorted_ids, ref_sorted_weights, ref_sorted_expert_ids, ref_unit_cnt);
|
||||
|
||||
int num_valid_padded = num_valid_host(0);
|
||||
int num_valid_tokens = num_valid_host(1);
|
||||
|
||||
if(num_valid_padded != ref_unit_cnt)
|
||||
{
|
||||
printf(" FAIL:num_valid[0] got %d ref %d;", num_valid_padded, ref_unit_cnt);
|
||||
pass = false;
|
||||
}
|
||||
if(num_valid_tokens != 1)
|
||||
{
|
||||
printf(" FAIL:num_valid[1] got %d;", num_valid_tokens);
|
||||
pass = false;
|
||||
}
|
||||
|
||||
int n_tiles = num_valid_padded / unit_size;
|
||||
for(int i = 1; i < n_tiles; i++)
|
||||
{
|
||||
if(sorted_expert_ids_host(i) < sorted_expert_ids_host(i - 1))
|
||||
{
|
||||
printf(" FAIL:expert_ids not ascending;");
|
||||
pass = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
WeightType wsum = 0;
|
||||
for(int i = 0; i < topk; i++)
|
||||
wsum += sorted_weights_host(i * unit_size);
|
||||
if(std::abs(wsum - 1.0f) > 1e-3f)
|
||||
{
|
||||
printf(" FAIL:wsum=%.6f;", static_cast<float>(wsum));
|
||||
pass = false;
|
||||
}
|
||||
|
||||
bool buf_zeroed = std::all_of(
|
||||
moe_buf_host.begin(), moe_buf_host.end(), [](float v) { return v == 0.0f; });
|
||||
if(!buf_zeroed)
|
||||
{
|
||||
printf(" FAIL:moe_buf not zeroed;");
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
printf(" valid:%s\n", pass ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string pr = args.get_str("pr_i");
|
||||
bool r = true;
|
||||
|
||||
if(pr == "fp16")
|
||||
r &= run_test<ck_tile::fp16_t, float>(args);
|
||||
else if(pr == "bf16")
|
||||
r &= run_test<ck_tile::bf16_t, float>(args);
|
||||
else
|
||||
{
|
||||
printf("unsupported pr_i: %s\n", pr.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "topk_softmax_decode_api.hpp"
|
||||
|
||||
#define TOPK_SOFTMAX_DECODE_DISPATCH(experts_, use_softmax_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
constexpr bool ts_use_softmax = use_softmax_; \
|
||||
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
|
||||
ts_weight_type, \
|
||||
ts_index_type, \
|
||||
ts_experts, \
|
||||
ts_use_softmax>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxDecodePipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxDecodeKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
return ave_time;
|
||||
|
||||
#define TOPK_SOFTMAX_DECODE_EXPERT_LADDER(use_softmax_) \
|
||||
if(t.experts <= 8) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(8, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 16) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(16, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 32) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(32, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 64) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(64, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 128) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(128, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 192) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(192, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 256) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(256, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 512) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(512, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 1024) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(1024, use_softmax_)\
|
||||
}
|
||||
|
||||
float topk_softmax_decode(topk_softmax_decode_trait t,
|
||||
topk_softmax_decode_kargs a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
|
||||
}
|
||||
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/topk_softmax.hpp"
|
||||
#include <string>
|
||||
|
||||
struct topk_softmax_decode_trait
|
||||
{
|
||||
std::string input_type;
|
||||
std::string weight_type; // currently always float
|
||||
int experts;
|
||||
std::string activation; // "softmax" or "sigmoid"
|
||||
};
|
||||
|
||||
struct topk_softmax_decode_kargs : public ck_tile::TopkSoftmaxDecodeHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float topk_softmax_decode(topk_softmax_decode_trait t,
|
||||
topk_softmax_decode_kargs a,
|
||||
ck_tile::stream_config s);
|
||||
Reference in New Issue
Block a user