[CK_TILE] add scatter_gather (#1609)

This commit is contained in:
valarLip
2024-10-29 18:19:29 +08:00
committed by GitHub
parent 9fbd72e97e
commit 4d7e063a0a
6 changed files with 444 additions and 0 deletions

View File

@@ -23,6 +23,7 @@ enum struct coord_transform_enum
replicate,
xor_t,
offset,
indexing,
};
template <index_t NDimLow, index_t NDimUp>
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
}
};
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
static constexpr index_t NDimUp = 1;
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
UpLengths up_lengths_;
IndexingAdaptor iadaptor_;
CK_TILE_HOST_DEVICE constexpr indexing() = default;
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
const IndexingAdaptor& iadaptor)
: up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::indexing;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
iadaptor_.calculate_lower_index(idx_low, idx_up);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
IndexingAdaptor::is_known_at_compile_time();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
//*******************************************************************************************************
template <typename LowLength>
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
}
} // namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
namespace ck_tile {
template <typename UpLength, typename Indices>
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
const Indices& indices)
{
// by default we use the simplest one
return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
}
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE constexpr auto
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
{
return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
}
} // namespace ck_tile

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// pre-defined indexing adaptor used for indexing(scatter/gather)
// this version cache the index inside thread register(which is also prefered in real senario)
// however it's user's responsibility that each thread only provide one indexing, which means
// move coordinate will not change on this dim
template <typename IndexingType>
struct indexing_adaptor_onshot_cached
{
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
: cached_idx_(idx)
{
}
IndexingType cached_idx_;
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& /*idx_up*/) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = cached_idx_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& /*idx_low*/,
const UpIdx& /*idx_up*/) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<IndexingType>::value;
}
};
} // namespace ck_tile