mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
[CK_TILE] fused-moe first version (#1634)
* moe pipeline
* update code
* compile OK
* update
* update cpu reference
* update pipeline_gemm0
* compiler ok
* update pipeline
* rename to ex pipeline
* block-asm
* update
* update
* update first gemm ok
* compute correct
* update file structure
* update README
* update
* update
* update code
* update API
* return unsupport case
* add comment
* update readme
* update
* uncomment
* update
* fix build err
---------
Co-authored-by: valarLip <340077269@qq.com>
[ROCm/composable_kernel commit: 440e28b08f]
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <stdint.h>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename T>
|
||||
@@ -36,6 +37,19 @@ struct DeviceMem
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
|
||||
{
|
||||
if(mMemSize != 0)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
else
|
||||
{
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
ToDevice(t.data());
|
||||
}
|
||||
void Realloc(std::size_t mem_size)
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
@@ -92,6 +106,27 @@ struct DeviceMem
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
}
|
||||
|
||||
// construct a host tensor with type T
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost(std::size_t cpySize)
|
||||
{
|
||||
// TODO: host tensor could be slightly larger than the device tensor
|
||||
// we just copy all data from GPU buffer
|
||||
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
|
||||
HostTensor<T> h_({host_elements});
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
return h_;
|
||||
}
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost()
|
||||
{
|
||||
return ToHost<T>(mMemSize);
|
||||
}
|
||||
|
||||
void SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -22,13 +23,44 @@ struct FillUniformDistribution
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
// ATTENTION: threaded does not guarantee the distribution between thread
|
||||
bool threaded = false;
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
if(threaded)
|
||||
{
|
||||
uint32_t num_thread = std::thread::hardware_concurrency();
|
||||
auto total = static_cast<std::size_t>(std::distance(first, last));
|
||||
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t iw_begin = it * work_per_thread;
|
||||
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
|
||||
auto thread_f = [this, total, iw_begin, iw_end, &first] {
|
||||
if(iw_begin > total || iw_end > total)
|
||||
return;
|
||||
// need to make each thread unique, add an offset to current seed
|
||||
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
|
||||
: std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
});
|
||||
};
|
||||
threads[it] = joinable_thread(thread_f);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(
|
||||
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
@@ -115,13 +147,44 @@ struct FillNormalDistribution
|
||||
float mean_{0.f};
|
||||
float variance_{1.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
// ATTENTION: threaded does not guarantee the distribution between thread
|
||||
bool threaded = false;
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
|
||||
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
if(threaded)
|
||||
{
|
||||
uint32_t num_thread = std::thread::hardware_concurrency();
|
||||
auto total = static_cast<std::size_t>(std::distance(first, last));
|
||||
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t iw_begin = it * work_per_thread;
|
||||
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
|
||||
auto thread_f = [this, total, iw_begin, iw_end, &first] {
|
||||
if(iw_begin > total || iw_end > total)
|
||||
return;
|
||||
// need to make each thread unique, add an offset to current seed
|
||||
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
|
||||
: std::random_device{}());
|
||||
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
|
||||
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
});
|
||||
};
|
||||
threads[it] = joinable_thread(thread_f);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
|
||||
std::generate(
|
||||
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
@@ -235,6 +298,44 @@ struct FillMonotonicSeq
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, bool IsAscending = true>
|
||||
struct FillStepRange
|
||||
{
|
||||
float start_value_{0};
|
||||
float end_value_{3};
|
||||
float step_{1};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::generate(first, last, [=, n = start_value_]() mutable {
|
||||
auto tmp = n;
|
||||
n += step_;
|
||||
if constexpr(IsAscending)
|
||||
{
|
||||
if(n > end_value_)
|
||||
n = start_value_;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(n < end_value_)
|
||||
n = start_value_;
|
||||
}
|
||||
|
||||
return type_convert<T>(tmp);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<
|
||||
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FillConstant
|
||||
{
|
||||
|
||||
@@ -8,12 +8,13 @@
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <fstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
|
||||
return HostTensorDescriptor(new_lengths, new_strides);
|
||||
}
|
||||
|
||||
struct joinable_thread : std::thread
|
||||
{
|
||||
template <typename... Xs>
|
||||
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
|
||||
{
|
||||
}
|
||||
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread& operator=(joinable_thread&&) = default;
|
||||
|
||||
~joinable_thread()
|
||||
{
|
||||
if(this->joinable())
|
||||
this->join();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename F, typename... Xs>
|
||||
struct ParallelTensorFunctor
|
||||
{
|
||||
@@ -590,6 +574,107 @@ struct HostTensor
|
||||
size() * FromSize / ToSize};
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
|
||||
{
|
||||
os << t.mDesc;
|
||||
os << "[";
|
||||
for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
|
||||
{
|
||||
os << type_convert<float>(t.mData[idx]) << " #### ";
|
||||
}
|
||||
else
|
||||
{
|
||||
os << t.mData[idx];
|
||||
}
|
||||
}
|
||||
os << "]";
|
||||
return os;
|
||||
}
|
||||
|
||||
// read data from a file, as dtype
|
||||
// the file could dumped from torch as (targeting tensor is t here)
|
||||
// numpy.savetxt("f.txt", t.view(-1).numpy())
|
||||
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
|
||||
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
|
||||
// will output f.txt, each line is a value
|
||||
// dtype=float or int, internally will cast to real type
|
||||
void loadtxt(std::string file_name, std::string dtype = "float")
|
||||
{
|
||||
std::ifstream file(file_name);
|
||||
|
||||
if(file.is_open())
|
||||
{
|
||||
std::string line;
|
||||
|
||||
index_t cnt = 0;
|
||||
while(std::getline(file, line))
|
||||
{
|
||||
if(cnt >= static_cast<index_t>(mData.size()))
|
||||
{
|
||||
throw std::runtime_error(std::string("data read from file:") + file_name +
|
||||
" is too big");
|
||||
}
|
||||
|
||||
if(dtype == "float")
|
||||
{
|
||||
mData[cnt] = type_convert<T>(std::stof(line));
|
||||
}
|
||||
else if(dtype == "int" || dtype == "int32")
|
||||
{
|
||||
mData[cnt] = type_convert<T>(std::stoi(line));
|
||||
}
|
||||
cnt++;
|
||||
}
|
||||
file.close();
|
||||
if(cnt < static_cast<index_t>(mData.size()))
|
||||
{
|
||||
std::cerr << "Warning! reading from file:" << file_name
|
||||
<< ", does not match the size of this tensor" << std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Print an error message to the standard error
|
||||
// stream if the file cannot be opened.
|
||||
throw std::runtime_error(std::string("unable to open file:") + file_name);
|
||||
}
|
||||
}
|
||||
|
||||
// can save to a txt file and read from torch as:
|
||||
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
|
||||
void savetxt(std::string file_name, std::string dtype = "float")
|
||||
{
|
||||
std::ofstream file(file_name);
|
||||
|
||||
if(file.is_open())
|
||||
{
|
||||
for(auto& itm : mData)
|
||||
{
|
||||
if(dtype == "float")
|
||||
file << type_convert<float>(itm) << std::endl;
|
||||
else if(dtype == "int")
|
||||
file << type_convert<int>(itm) << std::endl;
|
||||
else
|
||||
// TODO: we didn't implement operator<< for all custom
|
||||
// data types, here fall back to float in case compile error
|
||||
file << type_convert<float>(itm) << std::endl;
|
||||
}
|
||||
file.close();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Print an error message to the standard error
|
||||
// stream if the file cannot be opened.
|
||||
throw std::runtime_error(std::string("unable to open file:") + file_name);
|
||||
}
|
||||
}
|
||||
|
||||
Descriptor mDesc;
|
||||
Data mData;
|
||||
};
|
||||
|
||||
27
include/ck_tile/host/joinable_thread.hpp
Normal file
27
include/ck_tile/host/joinable_thread.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct joinable_thread : std::thread
|
||||
{
|
||||
template <typename... Xs>
|
||||
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
|
||||
{
|
||||
}
|
||||
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread& operator=(joinable_thread&&) = default;
|
||||
|
||||
~joinable_thread()
|
||||
{
|
||||
if(this->joinable())
|
||||
this->join();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
196
include/ck_tile/host/reference/reference_fused_moe.hpp
Normal file
196
include/ck_tile/host/reference/reference_fused_moe.hpp
Normal file
@@ -0,0 +1,196 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
|
||||
// number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
|
||||
// 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
|
||||
// -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
|
||||
// c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
///
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
|
||||
template <typename AccDataType, // you only need to explcitly set this one
|
||||
typename Activation, // ck_tile::element_wise::Gelu
|
||||
typename ADataType,
|
||||
typename GDataType,
|
||||
typename DDataType,
|
||||
typename ODataType,
|
||||
typename AScaleDataType,
|
||||
typename GScaleDataType,
|
||||
typename DScaleDataType,
|
||||
typename YSmoothScaleDataType,
|
||||
typename TopkWeightDataType,
|
||||
typename IndexDataType>
|
||||
void reference_fused_moe(
|
||||
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
|
||||
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
|
||||
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
|
||||
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
|
||||
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
|
||||
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
|
||||
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
|
||||
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
|
||||
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
|
||||
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
|
||||
const ck_tile::HostTensor<IndexDataType>&
|
||||
sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
|
||||
const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
|
||||
|
||||
const ck_tile::HostTensor<IndexDataType>&
|
||||
token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
|
||||
|
||||
ck_tile::index_t block_m,
|
||||
ck_tile::index_t tokens,
|
||||
ck_tile::index_t experts,
|
||||
ck_tile::index_t hidden_size,
|
||||
ck_tile::index_t intermediate_size, // this size is for gate/up
|
||||
ck_tile::index_t topk,
|
||||
ck_tile::index_t gate_only)
|
||||
{
|
||||
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
|
||||
assert(sorted_weight_host.get_num_of_dimension() == 1);
|
||||
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
|
||||
assert(num_sorted_tiles_host.get_element_size() == 1);
|
||||
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
|
||||
ck_tile::index_t intermediate_size_0 = intermediate_size;
|
||||
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
|
||||
|
||||
// TODO: better remove this in the future, or modify the token_id value
|
||||
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
|
||||
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
|
||||
{
|
||||
if(token_ids_host(token_id_, i_) == expert_id_)
|
||||
return i_;
|
||||
}
|
||||
throw std::runtime_error("not correct token/expert pair\n");
|
||||
return -1; // TODO: not correct!!
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
|
||||
|
||||
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
|
||||
// assert();
|
||||
auto f = [&](auto i_flatten) {
|
||||
ck_tile::index_t i_tile = i_flatten / block_m;
|
||||
if(i_tile >= num_sorted_tiles)
|
||||
return;
|
||||
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
|
||||
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
|
||||
if(i_token >= tokens)
|
||||
return;
|
||||
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
|
||||
auto weight = sorted_weight_host.mData[i_flatten];
|
||||
|
||||
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
|
||||
// first gemm
|
||||
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
|
||||
{
|
||||
AccDataType acc = static_cast<AccDataType>(0);
|
||||
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
|
||||
{
|
||||
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
|
||||
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
|
||||
}
|
||||
acc_0(0, i_n) = acc;
|
||||
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
|
||||
if(gate_only)
|
||||
{
|
||||
if(intermediate_size_1 != intermediate_size_0)
|
||||
throw std::runtime_error(
|
||||
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
|
||||
", 1:" + std::to_string(intermediate_size_1));
|
||||
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
|
||||
{
|
||||
Activation{}(y(0, i_n), acc_0(0, i_n));
|
||||
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(intermediate_size_1 * 2 != intermediate_size_0)
|
||||
throw std::runtime_error(
|
||||
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
|
||||
", 1:" + std::to_string(intermediate_size_1));
|
||||
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
|
||||
{
|
||||
AccDataType tmp;
|
||||
Activation{}(tmp, acc_0(0, i_n));
|
||||
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
|
||||
}
|
||||
}
|
||||
|
||||
// second gemm, loop along gemm-n
|
||||
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
|
||||
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
|
||||
{
|
||||
AccDataType acc = static_cast<AccDataType>(0);
|
||||
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
|
||||
{
|
||||
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
|
||||
}
|
||||
acc_1(0, i_n) = acc * weight; // multiple weight here
|
||||
}
|
||||
|
||||
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
|
||||
{
|
||||
out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
|
||||
|
||||
// reduce
|
||||
auto r = [&](auto i_token) {
|
||||
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
|
||||
{
|
||||
AccDataType acc = type_convert<AccDataType>(0);
|
||||
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
|
||||
{
|
||||
acc += out_topk_tokens(i_token, i_topk, i_n);
|
||||
}
|
||||
o_host(i_token, i_n) = type_convert<ODataType>(acc);
|
||||
}
|
||||
};
|
||||
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
|
||||
|
||||
(void)num_sorted_tiles_host;
|
||||
(void)sa_host;
|
||||
(void)sg_host;
|
||||
(void)sd_host;
|
||||
(void)sy_host;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -16,7 +16,7 @@ namespace ck_tile {
|
||||
*/
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST void
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
|
||||
{
|
||||
const auto x_len = x.mDesc.get_lengths();
|
||||
const auto y_len = y.mDesc.get_lengths();
|
||||
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
|
||||
std::vector<size_t> tmp(rank, 0);
|
||||
for(index_t i = 0; i < rank; i++)
|
||||
{
|
||||
tmp[dims[i]] = y_coord[i];
|
||||
tmp[perm[i]] = y_coord[i];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
|
||||
|
||||
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
|
||||
{
|
||||
auto x_shape = x.get_lengths();
|
||||
ck_tile::index_t rank = perm.size();
|
||||
std::vector<ck_tile::index_t> y_shape = [&]() {
|
||||
std::vector<ck_tile::index_t> tmp(rank, 0);
|
||||
for(int i = 0; i < static_cast<int>(rank); i++)
|
||||
{
|
||||
tmp[i] = x_shape[perm[i]];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
HostTensor<DataType> y(y_shape);
|
||||
reference_permute(x, y, perm);
|
||||
return y;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user