mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] add scatter_gather (#1609)
This commit is contained in:
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
|
||||
replicate,
|
||||
xor_t,
|
||||
offset,
|
||||
indexing,
|
||||
};
|
||||
|
||||
template <index_t NDimLow, index_t NDimUp>
|
||||
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
struct indexing : public base_transform<1, 1>
|
||||
{
|
||||
static constexpr index_t NDimUp = 1;
|
||||
|
||||
using LowerIndex = multi_index<1>;
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using UpLengths = decltype(make_tuple(UpLength{}));
|
||||
UpLengths up_lengths_;
|
||||
IndexingAdaptor iadaptor_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
|
||||
const IndexingAdaptor& iadaptor)
|
||||
: up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
|
||||
{
|
||||
return coord_transform_enum::indexing;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
|
||||
"wrong! inconsistent # of dimension");
|
||||
iadaptor_.calculate_lower_index(idx_low, idx_up);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
|
||||
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_always_mapped_to_valid_lower_index()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
IndexingAdaptor::is_known_at_compile_time();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename UpLength, typename Indices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
|
||||
const Indices& indices)
|
||||
{
|
||||
// by default we use the simplest one
|
||||
return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
|
||||
up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
|
||||
}
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
|
||||
{
|
||||
return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// pre-defined indexing adaptor used for indexing(scatter/gather)
|
||||
|
||||
// this version cache the index inside thread register(which is also prefered in real senario)
|
||||
// however it's user's responsibility that each thread only provide one indexing, which means
|
||||
// move coordinate will not change on this dim
|
||||
template <typename IndexingType>
|
||||
struct indexing_adaptor_onshot_cached
|
||||
{
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
|
||||
: cached_idx_(idx)
|
||||
{
|
||||
}
|
||||
IndexingType cached_idx_;
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = cached_idx_;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& /*idx_low*/,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
|
||||
|
||||
// pass the diff to lower, but not changing the actually index
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user