mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Top-K with Sigmoid kernel (#3062)
* Add sigmoid option to topk_softmax
* fix formatting
* add to changelog
* Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Use else if
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 515e283091]
This commit is contained in:
@@ -20,6 +20,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added tensor-wise quantization for CK_TILE GEMM.
|
||||
* Added support for batched contraction kernel.
|
||||
* Added pooling kernel in CK_TILE
|
||||
* Added top-k sigmoid kernel in CK_TILE
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -83,6 +83,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
|
||||
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
auto reference_topk_sigmoid(const ck_tile::HostTensor<InputType>& x,
|
||||
ck_tile::HostTensor<WeightType>& y_values,
|
||||
ck_tile::HostTensor<IndexType>& y_indices,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// topk only - no need to apply the sigmoid first
|
||||
auto x_fp32 = x.template CopyAsType<float>();
|
||||
reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted);
|
||||
// apply sigmoid
|
||||
std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) {
|
||||
return WeightType(1) / (WeightType(1) + exp(-value));
|
||||
});
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
@@ -133,7 +153,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "topk_softmax.json", "json file name to dump results");
|
||||
.insert("jsonfile", "topk_softmax.json", "json file name to dump results")
|
||||
.insert("activation", "softmax", "activation function to use: softmax or sigmoid");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -154,6 +175,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
std::string activation = args.get_str("activation");
|
||||
|
||||
if(stride_input < 0)
|
||||
{
|
||||
@@ -204,7 +226,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
topk_softmax_trait trait{input_prec, weight_prec, experts};
|
||||
topk_softmax_trait trait{input_prec, weight_prec, experts, activation};
|
||||
|
||||
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
|
||||
value_dev.GetDeviceBuffer(),
|
||||
@@ -221,7 +243,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = topk_softmax(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
|
||||
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ",
|
||||
input_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
@@ -229,6 +251,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
topk,
|
||||
stride_input,
|
||||
stride_output,
|
||||
activation.c_str(),
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
@@ -247,8 +270,20 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
|
||||
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
|
||||
|
||||
reference_topk_softmax<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
if(activation == "softmax")
|
||||
{
|
||||
reference_topk_softmax<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
}
|
||||
else if(activation == "sigmoid")
|
||||
{
|
||||
reference_topk_sigmoid<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("unsupported activation type: " + activation);
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<InputType>("");
|
||||
for(int i_t = 0; i_t < tokens; i_t++)
|
||||
|
||||
@@ -3,27 +3,31 @@
|
||||
|
||||
#include "topk_softmax_api.hpp"
|
||||
|
||||
#define TOPK_SOFTMAX_DISPATCH(experts_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
using ts_problem = ck_tile:: \
|
||||
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
constexpr bool ts_use_softmax = use_softmax_; \
|
||||
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
|
||||
ts_weight_type, \
|
||||
ts_index_type, \
|
||||
ts_experts, \
|
||||
ts_use_softmax>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
return ave_time;
|
||||
|
||||
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32")
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
@@ -31,36 +35,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
|
||||
#if 1
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8)
|
||||
TOPK_SOFTMAX_DISPATCH(8, true)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16)
|
||||
TOPK_SOFTMAX_DISPATCH(16, true)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32)
|
||||
TOPK_SOFTMAX_DISPATCH(32, true)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64)
|
||||
TOPK_SOFTMAX_DISPATCH(64, true)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
TOPK_SOFTMAX_DISPATCH(128, true)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192)
|
||||
TOPK_SOFTMAX_DISPATCH(192, true)
|
||||
}
|
||||
#else
|
||||
if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
TOPK_SOFTMAX_DISPATCH(128, true)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32")
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
#if 1
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
@@ -68,27 +72,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8)
|
||||
TOPK_SOFTMAX_DISPATCH(8, true)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16)
|
||||
TOPK_SOFTMAX_DISPATCH(16, true)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32)
|
||||
TOPK_SOFTMAX_DISPATCH(32, true)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64)
|
||||
TOPK_SOFTMAX_DISPATCH(64, true)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
TOPK_SOFTMAX_DISPATCH(128, true)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192)
|
||||
TOPK_SOFTMAX_DISPATCH(192, true)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
#if 1
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8, false)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16, false)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32, false)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64, false)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128, false)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192, false)
|
||||
}
|
||||
#else
|
||||
if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128, false)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
#if 1
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8, false)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16, false)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32, false)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64, false)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128, false)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192, false)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ struct topk_softmax_trait
|
||||
std::string input_type;
|
||||
std::string weight_type; // currently always float
|
||||
int experts;
|
||||
std::string activation; // "softmax" or "sigmoid"
|
||||
};
|
||||
|
||||
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
|
||||
|
||||
@@ -21,7 +21,7 @@ struct TopkSoftmaxHostArgs
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
index_t stride_output; // row stride for output/indices, at least topk
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
@@ -45,7 +45,7 @@ struct TopkSoftmaxKernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
index_t stride_output; // row stride for output/indices, at least topk
|
||||
};
|
||||
|
||||
using Kargs = TopkSoftmaxKargs;
|
||||
|
||||
@@ -90,6 +90,11 @@ struct TopkSoftmaxWarpPerRowPipeline
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
w_(idx) =
|
||||
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
|
||||
if constexpr(!Problem::ActivationIsSoftmax)
|
||||
{
|
||||
// sigmoid can be pre-computed already here if not using softmax
|
||||
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
|
||||
}
|
||||
};
|
||||
tile_sweeper ts{w_, w_f};
|
||||
ts();
|
||||
@@ -97,10 +102,16 @@ struct TopkSoftmaxWarpPerRowPipeline
|
||||
#endif
|
||||
}();
|
||||
|
||||
// softmax
|
||||
auto y = softmax(w);
|
||||
|
||||
topk(y, out_win, idx_win, k);
|
||||
if constexpr(Problem::ActivationIsSoftmax)
|
||||
{
|
||||
auto y = softmax(w);
|
||||
topk(y, out_win, idx_win, k);
|
||||
}
|
||||
else
|
||||
{
|
||||
// sigmoid was already pre-computed above, so only do topk now
|
||||
topk(w, out_win, idx_win, k);
|
||||
}
|
||||
|
||||
// check exit
|
||||
if constexpr(Problem::LaunchType == 0)
|
||||
|
||||
@@ -13,10 +13,11 @@ template <typename InputType_,
|
||||
typename WeightType_,
|
||||
typename IndexType_,
|
||||
index_t Experts_,
|
||||
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
|
||||
index_t BytesPerIssue_ = sizeof(InputType_),
|
||||
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
|
||||
index_t BlockSize_ = 256>
|
||||
bool ActivationIsSoftmax_ = true, // false: sigmoid
|
||||
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
|
||||
index_t BytesPerIssue_ = sizeof(InputType_),
|
||||
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
|
||||
index_t BlockSize_ = 256>
|
||||
struct TopkSoftmaxWarpPerRowProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
@@ -31,6 +32,8 @@ struct TopkSoftmaxWarpPerRowProblem
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
|
||||
static constexpr bool ActivationIsSoftmax = ActivationIsSoftmax_;
|
||||
|
||||
static_assert(BytesPerIssue % sizeof(InputType) == 0);
|
||||
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
|
||||
static_assert(Experts % VectorSize == 0);
|
||||
|
||||
@@ -39,6 +39,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
|
||||
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
auto reference_topk_sigmoid(const ck_tile::HostTensor<InputType>& x,
|
||||
ck_tile::HostTensor<WeightType>& y_values,
|
||||
ck_tile::HostTensor<IndexType>& y_indices,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// topk only - no need to apply the sigmoid first
|
||||
auto x_fp32 = x.template CopyAsType<float>();
|
||||
reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted);
|
||||
// apply sigmoid
|
||||
std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) {
|
||||
return WeightType(1) / (WeightType(1) + exp(-value));
|
||||
});
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
@@ -87,7 +107,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("activation", "softmax", "activation function to use: softmax or sigmoid");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -108,6 +129,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
std::string activation = args.get_str("activation");
|
||||
|
||||
if(stride_input < 0)
|
||||
{
|
||||
@@ -158,7 +180,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
topk_softmax_trait trait{input_prec, weight_prec, experts};
|
||||
topk_softmax_trait trait{input_prec, weight_prec, experts, activation};
|
||||
|
||||
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
|
||||
value_dev.GetDeviceBuffer(),
|
||||
@@ -175,7 +197,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = topk_softmax(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
|
||||
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ",
|
||||
input_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
@@ -183,6 +205,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
topk,
|
||||
stride_input,
|
||||
stride_output,
|
||||
activation.c_str(),
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
@@ -201,8 +224,20 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
|
||||
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
|
||||
|
||||
reference_topk_softmax<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
if(activation == "softmax")
|
||||
{
|
||||
reference_topk_softmax<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
}
|
||||
else if(activation == "sigmoid")
|
||||
{
|
||||
reference_topk_sigmoid<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("unsupported activation type: " + activation);
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<InputType>("");
|
||||
for(int i_t = 0; i_t < tokens; i_t++)
|
||||
@@ -255,7 +290,10 @@ int run_gemm_combinations(std::string const& data_type)
|
||||
{"-t=71", "-e=11", "-k=11", "-st_i=30", "-st_o=12"},
|
||||
{"-t=1", "-e=1", "-k=1"},
|
||||
{"-t=99", "-e=2", "-k=1", "-st_i=11", "-st_o=5"},
|
||||
{"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"}};
|
||||
{"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"},
|
||||
{"-t=20", "-e=5", "-k=2", "-activation=sigmoid"},
|
||||
{"-t=220", "-e=9", "-k=3", "-activation=sigmoid"},
|
||||
{"-t=500", "-e=21", "-k=13", "-activation=sigmoid"}};
|
||||
|
||||
bool result = true;
|
||||
std::string pr_i = "-pr_i=" + data_type;
|
||||
|
||||
@@ -3,27 +3,31 @@
|
||||
|
||||
#include "test_topk_softmax_api.hpp"
|
||||
|
||||
#define TOPK_SOFTMAX_DISPATCH(experts_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
using ts_problem = ck_tile:: \
|
||||
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
constexpr bool ts_use_softmax = use_softmax_; \
|
||||
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
|
||||
ts_weight_type, \
|
||||
ts_index_type, \
|
||||
ts_experts, \
|
||||
ts_use_softmax>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
return ave_time;
|
||||
|
||||
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32")
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
@@ -31,36 +35,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
|
||||
#if 1
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8)
|
||||
TOPK_SOFTMAX_DISPATCH(8, true)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16)
|
||||
TOPK_SOFTMAX_DISPATCH(16, true)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32)
|
||||
TOPK_SOFTMAX_DISPATCH(32, true)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64)
|
||||
TOPK_SOFTMAX_DISPATCH(64, true)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
TOPK_SOFTMAX_DISPATCH(128, true)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192)
|
||||
TOPK_SOFTMAX_DISPATCH(192, true)
|
||||
}
|
||||
#else
|
||||
if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
TOPK_SOFTMAX_DISPATCH(128, true)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32")
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
#if 1
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
@@ -68,27 +72,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8)
|
||||
TOPK_SOFTMAX_DISPATCH(8, true)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16)
|
||||
TOPK_SOFTMAX_DISPATCH(16, true)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32)
|
||||
TOPK_SOFTMAX_DISPATCH(32, true)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64)
|
||||
TOPK_SOFTMAX_DISPATCH(64, true)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
TOPK_SOFTMAX_DISPATCH(128, true)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192)
|
||||
TOPK_SOFTMAX_DISPATCH(192, true)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
#if 1
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8, false)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16, false)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32, false)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64, false)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128, false)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192, false)
|
||||
}
|
||||
#else
|
||||
if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128, false)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
#if 1
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8, false)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16, false)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32, false)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64, false)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128, false)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192, false)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ struct topk_softmax_trait
|
||||
std::string input_type;
|
||||
std::string weight_type; // currently always float
|
||||
int experts;
|
||||
std::string activation; // "softmax" or "sigmoid"
|
||||
};
|
||||
|
||||
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
|
||||
|
||||
Reference in New Issue
Block a user