[CK_TILE] fused-moe first version (#1634)

* moe pipeline

* update code

* compile OK

* update

* update cpu reference

* update pipeline_gemm0

* compiler ok

* update pipeline

* rename to ex pipeline

* block-asm

* update

* update

* update first gemm ok

* compute correct

* update file structure

* update README

* update

* update

* update code

* update API

* return unsupport case

* add comment

* update readme

* update

* uncomment

* update

* fix build err

---------

Co-authored-by: valarLip <340077269@qq.com>

[ROCm/composable_kernel commit: 440e28b08f]
This commit is contained in:
carlushuang
2024-11-26 11:14:56 +08:00
committed by GitHub
parent f81addbe42
commit 8acce2dee1
66 changed files with 8066 additions and 308 deletions

View File

@@ -7,6 +7,7 @@
#include <stdint.h>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename T>
@@ -36,6 +37,19 @@ struct DeviceMem
mpDeviceBuf = nullptr;
}
}
template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{
if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
ToDevice(t.data());
}
void Realloc(std::size_t mem_size)
{
if(mpDeviceBuf)
@@ -92,6 +106,27 @@ struct DeviceMem
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
}
// construct a host tensor with type T
template <typename T>
HostTensor<T> ToHost(std::size_t cpySize)
{
// TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
HostTensor<T> h_({host_elements});
if(mpDeviceBuf)
{
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
return h_;
}
template <typename T>
HostTensor<T> ToHost()
{
return ToHost<T>(mMemSize);
}
void SetZero() const
{
if(mpDeviceBuf)

View File

@@ -13,6 +13,7 @@
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
namespace ck_tile {
@@ -22,13 +23,44 @@ struct FillUniformDistribution
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: threaded does not guarantee the distribution between thread
bool threaded = false;
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
if(threaded)
{
uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
}
template <typename ForwardRange>
@@ -115,13 +147,44 @@ struct FillNormalDistribution
float mean_{0.f};
float variance_{1.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: threaded does not guarantee the distribution between thread
bool threaded = false;
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
if(threaded)
{
uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
}
template <typename ForwardRange>
@@ -235,6 +298,44 @@ struct FillMonotonicSeq
}
};
template <typename T, bool IsAscending = true>
struct FillStepRange
{
float start_value_{0};
float end_value_{3};
float step_{1};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = start_value_]() mutable {
auto tmp = n;
n += step_;
if constexpr(IsAscending)
{
if(n > end_value_)
n = start_value_;
}
else
{
if(n < end_value_)
n = start_value_;
}
return type_convert<T>(tmp);
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
template <typename T>
struct FillConstant
{

View File

@@ -8,12 +8,13 @@
#include <iostream>
#include <iomanip>
#include <numeric>
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include <fstream>
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp"
namespace ck_tile {
@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
return HostTensorDescriptor(new_lengths, new_strides);
}
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <typename F, typename... Xs>
struct ParallelTensorFunctor
{
@@ -590,6 +574,107 @@ struct HostTensor
size() * FromSize / ToSize};
}
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
{
os << t.mDesc;
os << "[";
for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
{
os << type_convert<float>(t.mData[idx]) << " #### ";
}
else
{
os << t.mData[idx];
}
}
os << "]";
return os;
}
// read data from a file, as dtype
// the file could dumped from torch as (targeting tensor is t here)
// numpy.savetxt("f.txt", t.view(-1).numpy())
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
// will output f.txt, each line is a value
// dtype=float or int, internally will cast to real type
void loadtxt(std::string file_name, std::string dtype = "float")
{
std::ifstream file(file_name);
if(file.is_open())
{
std::string line;
index_t cnt = 0;
while(std::getline(file, line))
{
if(cnt >= static_cast<index_t>(mData.size()))
{
throw std::runtime_error(std::string("data read from file:") + file_name +
" is too big");
}
if(dtype == "float")
{
mData[cnt] = type_convert<T>(std::stof(line));
}
else if(dtype == "int" || dtype == "int32")
{
mData[cnt] = type_convert<T>(std::stoi(line));
}
cnt++;
}
file.close();
if(cnt < static_cast<index_t>(mData.size()))
{
std::cerr << "Warning! reading from file:" << file_name
<< ", does not match the size of this tensor" << std::endl;
}
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
// can save to a txt file and read from torch as:
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
void savetxt(std::string file_name, std::string dtype = "float")
{
std::ofstream file(file_name);
if(file.is_open())
{
for(auto& itm : mData)
{
if(dtype == "float")
file << type_convert<float>(itm) << std::endl;
else if(dtype == "int")
file << type_convert<int>(itm) << std::endl;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error
file << type_convert<float>(itm) << std::endl;
}
file.close();
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
Descriptor mDesc;
Data mData;
};

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <thread>
#include <utility>
namespace ck_tile {
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,196 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
// 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
// -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
///
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template <typename AccDataType, // you only need to explcitly set this one
typename Activation, // ck_tile::element_wise::Gelu
typename ADataType,
typename GDataType,
typename DDataType,
typename ODataType,
typename AScaleDataType,
typename GScaleDataType,
typename DScaleDataType,
typename YSmoothScaleDataType,
typename TopkWeightDataType,
typename IndexDataType>
void reference_fused_moe(
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<IndexDataType>&
sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
const ck_tile::HostTensor<IndexDataType>&
token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
ck_tile::index_t block_m,
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
assert(sorted_weight_host.get_num_of_dimension() == 1);
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert();
auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m;
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
{
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
}
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
if(gate_only)
{
if(intermediate_size_1 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
Activation{}(y(0, i_n), acc_0(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
else
{
if(intermediate_size_1 * 2 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
}
}
// second gemm, loop along gemm-n
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
}
acc_1(0, i_n) = acc * weight; // multiple weight here
}
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
}
};
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
// reduce
auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = type_convert<AccDataType>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{
acc += out_topk_tokens(i_token, i_topk, i_n);
}
o_host(i_token, i_n) = type_convert<ODataType>(acc);
}
};
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
(void)num_sorted_tiles_host;
(void)sa_host;
(void)sg_host;
(void)sd_host;
(void)sy_host;
}
} // namespace ck_tile

View File

@@ -16,7 +16,7 @@ namespace ck_tile {
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
const auto x_len = x.mDesc.get_lengths();
const auto y_len = y.mDesc.get_lengths();
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
std::vector<size_t> tmp(rank, 0);
for(index_t i = 0; i < rank; i++)
{
tmp[dims[i]] = y_coord[i];
tmp[perm[i]] = y_coord[i];
}
return tmp;
}();
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}
template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
auto x_shape = x.get_lengths();
ck_tile::index_t rank = perm.size();
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++)
{
tmp[i] = x_shape[perm[i]];
}
return tmp;
}();
HostTensor<DataType> y(y_shape);
reference_permute(x, y, perm);
return y;
}
} // namespace ck_tile