mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] fused-moe first version (#1634)
* moe pipeline * update code * compile OK * update * update cpu reference * update pipeline_gemm0 * compiler ok * update pipeline * rename to ex pipeline * block-asm * update * update * update first gemm ok * compute correct * update file structure * update README * update * update * update code * update API * return unsupport case * add comment * update readme * update * uncomment * update * fix build err --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -16,7 +16,7 @@ namespace ck_tile {
|
||||
*/
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST void
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
|
||||
{
|
||||
const auto x_len = x.mDesc.get_lengths();
|
||||
const auto y_len = y.mDesc.get_lengths();
|
||||
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
|
||||
std::vector<size_t> tmp(rank, 0);
|
||||
for(index_t i = 0; i < rank; i++)
|
||||
{
|
||||
tmp[dims[i]] = y_coord[i];
|
||||
tmp[perm[i]] = y_coord[i];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
|
||||
|
||||
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
|
||||
{
|
||||
auto x_shape = x.get_lengths();
|
||||
ck_tile::index_t rank = perm.size();
|
||||
std::vector<ck_tile::index_t> y_shape = [&]() {
|
||||
std::vector<ck_tile::index_t> tmp(rank, 0);
|
||||
for(int i = 0; i < static_cast<int>(rank); i++)
|
||||
{
|
||||
tmp[i] = x_shape[perm[i]];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
HostTensor<DataType> y(y_shape);
|
||||
reference_permute(x, y, perm);
|
||||
return y;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user