mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Add a version of Merge transform that use integerdivision and mod (#25)
* add Merg_v3_division_mod * refactor
This commit is contained in:
@@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division
|
||||
}
|
||||
};
|
||||
|
||||
// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
|
||||
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
|
||||
// will be very bad
|
||||
template <typename LowLengths>
|
||||
struct Merge_v3_division_mod
|
||||
{
|
||||
static constexpr index_t NDimLow = LowLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<NDimLow>;
|
||||
using UpperIndex = MultiIndex<1>;
|
||||
|
||||
using LowLengthsScan =
|
||||
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
|
||||
|
||||
LowLengths low_lengths_;
|
||||
LowLengthsScan low_lengths_scan_;
|
||||
UpLengths up_lengths_;
|
||||
|
||||
__host__ __device__ constexpr Merge_v3_division_mod() = default;
|
||||
|
||||
__host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
|
||||
: low_lengths_{low_lengths},
|
||||
low_lengths_scan_{
|
||||
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
|
||||
{
|
||||
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
|
||||
|
||||
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
index_t tmp = idx_up[Number<0>{}];
|
||||
|
||||
// division and mod
|
||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||
idx_low(i) = tmp / this->low_lengths_scan_[i];
|
||||
tmp %= this->low_lengths_scan_[i];
|
||||
});
|
||||
|
||||
idx_low(Number<NDimLow - 1>{}) = tmp;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff,
|
||||
typename UpIdxDiff,
|
||||
typename LowIdx,
|
||||
typename UpIdx,
|
||||
index_t Hack>
|
||||
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff&,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up_new,
|
||||
Number<Hack>) const
|
||||
{
|
||||
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
|
||||
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto INm1 = Number<NDimLow - 1>{};
|
||||
|
||||
index_t tmp = idx_up_new[I0];
|
||||
|
||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||
const index_t tmp2 = idx_low[i];
|
||||
idx_low(i) = tmp / this->low_lengths_scan_[i];
|
||||
idx_diff_low(i) = idx_low[i] - tmp2;
|
||||
tmp %= this->low_lengths_scan_[i];
|
||||
});
|
||||
|
||||
const index_t tmp2 = idx_low[INm1];
|
||||
idx_low(INm1) = tmp;
|
||||
idx_diff_low(INm1) = idx_low[INm1] - tmp2;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
return is_known_at_compile_time<LowLengths>::value &&
|
||||
is_known_at_compile_time<LowLengthsScan>::value &&
|
||||
is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
__host__ __device__ static constexpr bool
|
||||
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("Merge_v3_direct_division_mod, ");
|
||||
printf("low_lengths_ ");
|
||||
print_multi_index(low_lengths_);
|
||||
printf("low_lengths_scan_ ");
|
||||
print_multi_index(low_lengths_scan_);
|
||||
printf("up_lengths_ ");
|
||||
print_multi_index(up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
struct UnMerge
|
||||
{
|
||||
|
||||
@@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
||||
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
||||
return make_merge_transform_v2_magic_division(low_lengths);
|
||||
#else
|
||||
return make_merge_transform_v1_carry_check(low_lengths);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
|
||||
{
|
||||
return Merge_v1_carry_check<LowLengths>{low_lengths};
|
||||
#else
|
||||
#if 1
|
||||
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||
#else
|
||||
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
|
||||
{
|
||||
#if 1
|
||||
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||
#else
|
||||
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
|
||||
{
|
||||
return Merge_v3_division_mod<LowLengths>{low_lengths};
|
||||
}
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
|
||||
|
||||
Reference in New Issue
Block a user