mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Initial implementation of magic number division and "Merge" transformation that use it (#28)
* initial implementation for magic number division and DynamicMerge_v2_magic_division that uses it * turn off DynamicMerge_v2_magic_division that use magic number division by default
This commit is contained in:
@@ -467,8 +467,10 @@ struct DynamicEmbed
|
||||
}
|
||||
};
|
||||
|
||||
// Implementation of "Merge" transformation primitive that uses regular to do lowering of
|
||||
// multi-index and use carry-and-borrow check to do lowering of multi-index delta
|
||||
template <typename LowLengths>
|
||||
struct DynamicMerge
|
||||
struct DynamicMerge_v1_carry_check
|
||||
{
|
||||
static constexpr index_t NDimLow = LowLengths::Size();
|
||||
|
||||
@@ -485,9 +487,9 @@ struct DynamicMerge
|
||||
LowLengthsScan low_lengths_scan_;
|
||||
UpLengths up_lengths_;
|
||||
|
||||
__host__ __device__ constexpr DynamicMerge() = default;
|
||||
__host__ __device__ constexpr DynamicMerge_v1_carry_check() = default;
|
||||
|
||||
__host__ __device__ constexpr DynamicMerge(const LowLengths& low_lengths)
|
||||
__host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths)
|
||||
: low_lengths_{low_lengths},
|
||||
low_lengths_scan_{
|
||||
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
|
||||
@@ -511,7 +513,8 @@ struct DynamicMerge
|
||||
|
||||
index_t tmp = idx_up[Number<0>{}];
|
||||
|
||||
static_for<0, NDimLow - 1, 1>{}([&idx_low, &tmp, this](auto i) {
|
||||
// normal division
|
||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||
idx_low(i) = tmp / this->low_lengths_scan_[i];
|
||||
tmp -= idx_low[i] * this->low_lengths_scan_[i];
|
||||
});
|
||||
@@ -978,7 +981,7 @@ struct DynamicMerge
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("DynamicMerge, ");
|
||||
printf("DynamicMerge_v1_carry_check, ");
|
||||
printf("low_lengths_ ");
|
||||
print_multi_index(low_lengths_);
|
||||
printf("low_lengths_scan_ ");
|
||||
@@ -989,6 +992,178 @@ struct DynamicMerge
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier
|
||||
{
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator()(Number<I> i) const
|
||||
{
|
||||
return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
struct lambda_merge_generate_MagicDivision_calculate_magic_shift
|
||||
{
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator()(Number<I> i) const
|
||||
{
|
||||
return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
|
||||
}
|
||||
};
|
||||
|
||||
// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
|
||||
// of both multi-index and delta of multi-index
|
||||
// Caution:
|
||||
// 1. The magic number division implementation being used would produce correct result if the
|
||||
// dividended is uint32_t and its value is with in 31-bit value range of uint32_t.
|
||||
// 2. The magic number division for int32_t dividened has not been implemented, the int32_t
|
||||
// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for
|
||||
// uint32_t is then used.
|
||||
// 3. For Merge primitive, upper-index is the dividend.
|
||||
// 4. When upper-index is uint32_t, its value need to be within 31-bit range.
|
||||
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
|
||||
// non-negative.
|
||||
template <typename LowLengths>
|
||||
struct DynamicMerge_v2_magic_division
|
||||
{
|
||||
static constexpr index_t NDimLow = LowLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<NDimLow>;
|
||||
using UpperIndex = MultiIndex<1>;
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
|
||||
|
||||
using LowLengthsMagicDivisorMultipiler = decltype(
|
||||
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
|
||||
Number<NDimLow>{}));
|
||||
|
||||
using LowLengthsMagicDivisorShift = decltype(
|
||||
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
|
||||
Number<NDimLow>{}));
|
||||
|
||||
LowLengths low_lengths_;
|
||||
LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;
|
||||
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
|
||||
UpLengths up_lengths_;
|
||||
|
||||
__host__ __device__ constexpr DynamicMerge_v2_magic_division() = default;
|
||||
|
||||
__host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths)
|
||||
: low_lengths_{low_lengths},
|
||||
low_lengths_magic_divisor_multiplier_{generate_tuple(
|
||||
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
|
||||
Number<NDimLow>{})},
|
||||
low_lengths_magic_divisor_shift_{generate_tuple(
|
||||
[&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
|
||||
Number<NDimLow>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
|
||||
{
|
||||
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
|
||||
|
||||
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
index_t tmp = idx_up[Number<0>{}];
|
||||
|
||||
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
|
||||
index_t tmp2 =
|
||||
MagicDivision::DoMagicDivision(tmp,
|
||||
this->low_lengths_magic_divisor_multiplier_[i],
|
||||
this->low_lengths_magic_divisor_shift_[i]);
|
||||
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
|
||||
tmp = tmp2;
|
||||
});
|
||||
|
||||
idx_low(Number<0>{}) = tmp;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff,
|
||||
typename UpIdxDiff,
|
||||
typename LowIdx,
|
||||
typename UpIdx,
|
||||
index_t Hack>
|
||||
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up_new,
|
||||
Number<Hack>) const
|
||||
{
|
||||
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
|
||||
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
index_t tmp = idx_up_new[Number<0>{}];
|
||||
|
||||
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
|
||||
index_t tmp2 =
|
||||
MagicDivision::DoMagicDivision(tmp,
|
||||
this->low_lengths_magic_divisor_multiplier_[i],
|
||||
this->low_lengths_magic_divisor_shift_[i]);
|
||||
|
||||
index_t idx_low_old = idx_low[i];
|
||||
|
||||
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
|
||||
tmp = tmp2;
|
||||
|
||||
idx_diff_low(i) = idx_low[i] - idx_low_old;
|
||||
});
|
||||
|
||||
idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
|
||||
|
||||
idx_low(Number<0>{}) = tmp;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
return is_known_at_compile_time<LowLengths>::value &&
|
||||
is_known_at_compile_time<LowLengthsMagicDivisorMultipiler>::value &&
|
||||
is_known_at_compile_time<LowLengthsMagicDivisorShift>::value &&
|
||||
is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
__host__ __device__ static constexpr bool
|
||||
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("DynamicMerge_v2_magic_division, ");
|
||||
printf("low_lengths_ ");
|
||||
print_multi_index(low_lengths_);
|
||||
printf("low_lengths_magic_divisor_multiplier_ ");
|
||||
print_multi_index(low_lengths_magic_divisor_multiplier_);
|
||||
printf("low_lengths_magic_divisor_shift_ ");
|
||||
print_multi_index(low_lengths_magic_divisor_shift_);
|
||||
printf("up_lengths_ ");
|
||||
print_multi_index(up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
struct DynamicUnMerge
|
||||
{
|
||||
|
||||
@@ -53,7 +53,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
return DynamicMerge<LowLengths>{low_lengths};
|
||||
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
||||
return DynamicMerge_v1_carry_check<LowLengths>{low_lengths};
|
||||
#else
|
||||
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "tuple_helper.hpp"
|
||||
#include "type.hpp"
|
||||
#include "utility.hpp"
|
||||
#include "magic_division.hpp"
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
#include "amd_inline_asm.hpp"
|
||||
|
||||
@@ -115,6 +115,9 @@
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
|
||||
|
||||
// merge transformation use magic number division
|
||||
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
|
||||
|
||||
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
|
||||
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
|
||||
// thread-invariant, otherwise it's a bug
|
||||
|
||||
136
composable_kernel/include/utility/magic_division.hpp
Normal file
136
composable_kernel/include/utility/magic_division.hpp
Normal file
@@ -0,0 +1,136 @@
|
||||
#ifndef CK_MAGIC_DIVISION_HPP
|
||||
#define CK_MAGIC_DIVISION_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
#include "tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// magic number division
|
||||
// Caution:
|
||||
// 1. For uint32_t as dividend: magic number division implementation being used would produce
|
||||
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
|
||||
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
|
||||
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
|
||||
// division implementation for uint32_t is then used. Therefore, dividend value need to be
|
||||
// non-negative.
|
||||
// TODO:
|
||||
// 1. Implement magic number divison for int32_t
|
||||
// 2. Implement magic number divison for unit32_t with 32-bit value range
|
||||
struct MagicDivision
|
||||
{
|
||||
// uint32_t
|
||||
__host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
|
||||
{
|
||||
// assert(divisior >= 1 && divisior <= INT32_MAX);
|
||||
uint32_t shift = 0;
|
||||
for(shift = 0; shift < 32; ++shift)
|
||||
{
|
||||
if((1U << shift) >= divisor)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t one = 1;
|
||||
uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
|
||||
// assert(multiplier <= 0xffffffffUL);
|
||||
|
||||
return make_tuple(uint32_t(multiplier), shift);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
|
||||
{
|
||||
auto tmp = CalculateMagicNumbers(divisor);
|
||||
|
||||
return tmp[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor)
|
||||
{
|
||||
auto tmp = CalculateMagicNumbers(divisor);
|
||||
|
||||
return tmp[Number<1>{}];
|
||||
}
|
||||
|
||||
// integral_constant<uint32_t, .>
|
||||
template <uint32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[Number<0>{}];
|
||||
constexpr uint32_t shift = tmp[Number<1>{}];
|
||||
|
||||
return make_tuple(integral_constant<uint32_t, multiplier>{},
|
||||
integral_constant<uint32_t, shift>{});
|
||||
}
|
||||
|
||||
template <uint32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
|
||||
|
||||
return integral_constant<uint32_t, multiplier>{};
|
||||
}
|
||||
|
||||
template <uint32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicShift(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
|
||||
|
||||
return integral_constant<uint32_t, shift>{};
|
||||
}
|
||||
|
||||
// integral_constant<int32_t, .>
|
||||
template <int32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicNumbers(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
template <int32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
template <int32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicShift(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
__host__ __device__ static constexpr uint32_t
|
||||
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
// HACK: magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
__host__ __device__ static constexpr int32_t
|
||||
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = ((uint64_t)dividend_u32 * (uint64_t)multiplier) >> 32;
|
||||
return (tmp + dividend_i32) >> shift;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -42,5 +42,19 @@ struct is_known_at_compile_time<integral_constant<T, X>>
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Y,
|
||||
typename X,
|
||||
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||
__host__ __device__ constexpr Y as_type(X x)
|
||||
{
|
||||
union AsType
|
||||
{
|
||||
X x;
|
||||
Y y;
|
||||
};
|
||||
|
||||
return AsType{x}.y;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user