[CK_TILE] add generic_permute (#1607)

This commit is contained in:
valarLip
2024-10-29 18:05:53 +08:00
committed by GitHub
parent 922e42a039
commit 9fbd72e97e
14 changed files with 1318 additions and 0 deletions

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
namespace ck_tile {
/* independent host side argument, no template
*/
struct GenericPermuteHostArgs
{
static constexpr index_t kMaxRanks = 8; // TODO: hardcoded
const void* p_src;
void* p_dst;
index_t rank;
index_t shape[kMaxRanks]; // input shape
index_t perm[kMaxRanks]; // permute index
};
/*
simulate torch.permute:
x_ = x_.view(x.shape[0],
x.shape[1]//16, 16,
x.shape[2]//32, 4, 8)
x_ = x_.permute(0,1,3,4,2,5)
x_ = x_.contiguous()
x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);//
this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks
dim of permutation, with a single kernel
*/
template <typename Problem_>
struct GenericPermute
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using DataType = remove_cvref_t<typename Problem::DataType>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMaxRanks = Problem::kMaxRanks;
static constexpr bool KeepLastDim = Problem::KeepLastDim;
struct __attribute__((packed)) Kargs
{
const void* p_src;
void* p_dst;
// index_t rank;
index_t num_elements;
index_t perm_length[kMaxRanks]; // tensor length after permutation
index_t perm_stride[kMaxRanks]; // tensor stride after permutation
};
CK_TILE_HOST static constexpr index_t TotalElements(const GenericPermuteHostArgs& h)
{
index_t n = 1;
for(auto i = 0; i < h.rank; i++)
{
n *= h.shape[i];
}
return n;
}
CK_TILE_HOST static constexpr Kargs MakeKargs(const GenericPermuteHostArgs& h)
{
Kargs a;
a.p_src = h.p_src;
a.p_dst = h.p_dst;
// assert rank <= kMaxRanks
index_t i = 0;
index_t perm[kMaxRanks];
index_t x_shape[kMaxRanks];
index_t x_stride[kMaxRanks];
// index_t perm_length[kMaxRanks];
for(; i < h.rank; i++)
{
x_shape[i] = h.shape[i];
perm[i] = h.perm[i];
}
for(; i < kMaxRanks; i++)
{
x_shape[i] = 1;
perm[i] = i; // will index to len = 1
}
index_t stride = 1;
for(index_t j = kMaxRanks - 1; j >= 0; j--)
{
x_stride[j] = stride;
stride *= x_shape[j];
}
for(index_t j = 0; j < kMaxRanks; j++)
{
a.perm_length[j] = x_shape[perm[j]];
a.perm_stride[j] = x_stride[perm[j]];
}
a.num_elements = TotalElements(h);
return a;
}
CK_TILE_HOST static constexpr auto GridSize(GenericPermuteHostArgs h)
{
auto total = TotalElements(h);
auto grids = dim3((total + BlockSize() - 1) / BlockSize());
// printf("### total:%d, grids:%dx%dx%d\n", total, );
return grids;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t id = blockIdx.x * BlockSize() + threadIdx.x;
if(id >= kargs.num_elements)
return;
const auto perm_length =
generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number<kMaxRanks>{});
const auto perm_stride =
generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number<kMaxRanks>{});
const DataType* p_src = reinterpret_cast<const DataType*>(kargs.p_src);
DataType* p_dst = reinterpret_cast<DataType*>(kargs.p_dst);
const auto src_view_0 = make_naive_tensor_view<address_space_enum::global>(
p_src, perm_length, perm_stride, number<1>{}, number<1>{});
const auto src_view = transform_tensor_view(
src_view_0,
make_tuple(make_merge_transform(perm_length)),
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
make_tuple(sequence<0>{}));
auto dst_view_0 = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst, perm_length, number<1>{});
auto dst_view = transform_tensor_view(
dst_view_0,
make_tuple(make_merge_transform(perm_length)),
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
make_tuple(sequence<0>{}));
// TODO: hard code to vector 1
using vector_t = thread_buffer<DataType, 1>;
const auto src_coord =
make_tensor_coordinate(src_view.get_tensor_descriptor(), array<index_t, 1>{id});
const auto dst_coord =
make_tensor_coordinate(dst_view.get_tensor_descriptor(), array<index_t, 1>{id});
// printf("src id:%d, os:%d\n", id, src_coord.get_offset());
// printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
const vector_t x = src_view.template get_vectorized_elements<vector_t>(src_coord, 0);
dst_view.template set_vectorized_elements<vector_t>(dst_coord, 0, x);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,28 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename DataType_,
index_t kBlockSize_ = 256,
index_t kMaxRanks_ = 8,
bool KeepLastDim_ = false>
struct GenericPermuteProblem
{
using DataType = remove_cvref_t<DataType_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMaxRanks = kMaxRanks_;
/* KeepLastDim:
* if last dim keep the same? this can help enable vector load
* permute(0, 2, 4, 1, 3, 5) -> true
* permute(0, 3, 2, 1) -> false
*/
static constexpr bool KeepLastDim = KeepLastDim_;
// TODO: not used(?)
};
} // namespace ck_tile