fix scratch in fp8 kernel

This commit is contained in:
carlushuang
2024-03-25 19:45:07 +00:00
parent 1cacb713c5
commit 200d2b22d4
4 changed files with 95 additions and 16 deletions

View File

@@ -59,6 +59,64 @@ struct thread_buffer {
// Compile-time-indexed mutable accessor: forwards I to get(I) and returns the result by reference.
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
// Const counterpart of at(number<I>): forwards I to get(I) and returns a const reference.
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
// Returns, by value, a thread_buffer<X_, N / kSPerX> filled from this buffer's N elements.
// The overload is only enabled (SFINAE) when has_same_scalar_type<value_type, X_> holds.
// Instead of reinterpret_cast-ing the storage, the elements are copied one by one into a
// local union and read back through the union's other member.
// NOTE(review): reading a union member other than the one last written relies on the
// compiler supporting union-based type punning -- confirm for every supported toolchain.
template <typename X_,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
{
using X = remove_cvref_t<X_>;
// vector_size that vector_traits reports for X (cv/ref stripped)
constexpr index_t kSPerX = vector_traits<X>::vector_size;
// the N elements must split evenly into groups of kSPerX
static_assert(N % kSPerX == 0);
union {
thread_buffer<X_, N / kSPerX> data {};
// tuple_array<value_type, kSPerX> sub_data;
value_type sub_data[N];
} vx;
// unqualified `data` below is this buffer's own storage, not vx.data
static_for<0, N, 1>{}(
[&](auto j) { vx.sub_data[j] = data[j]; });
return vx.data;
}
// Returns, by value, one X_ assembled from kSPerX consecutive elements of this buffer,
// starting at element index Is * (sizeof(X_) / sizeof(value_type)).
// The overload is only enabled (SFINAE) when has_same_scalar_type<value_type, X_> holds.
// The elements are copied into a local union and the union's X_ member is returned.
// NOTE(review): reading a union member other than the one last written relies on the
// compiler supporting union-based type punning -- confirm for every supported toolchain.
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
{
using X = remove_cvref_t<X_>;
// vector_size that vector_traits reports for X (cv/ref stripped)
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data {};
tuple_array<value_type, kSPerX> sub_data;
} vx;
// copy element (Is * elements-per-X_ + j) of this buffer into slot j of the union
static_for<0, kSPerX, 1>{}(
[&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
return vx.data;
}
// Compiled out (#if 0): setter counterpart of _get_as(number<Is>). It would initialize a
// local union from x and write its kSPerX sub-elements into this buffer, starting at
// element index Is * (sizeof(X_) / sizeof(value_type)).
#if 0
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data;
tuple_array<value_type, kSPerX> sub_data;
} vx {x};
static_for<0, kSPerX, 1>{}(
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
}
#endif
// Shared prologue for the get_as/set_as members: asserts that the buffer's byte size
// (sizeof(value_type) * N) is a whole multiple of sizeof(Tx), then defines `vx`, the number
// of Tx elements that fit in the buffer. Requires a type named Tx at the expansion site;
// the caller supplies the trailing semicolon.
#define TB_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
@@ -67,19 +125,26 @@ struct thread_buffer {
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
// Read-only access to the whole buffer as Tx elements (vx of them, see TB_COMMON_AS).
// When value_type is at most one byte wide the result is produced by _get_as<Tx>();
// otherwise `data` is reinterpret_cast to const thread_buffer<Tx, vx>&.
template<typename Tx>
CK_TILE_HOST_DEVICE const auto & get_as() const {TB_COMMON_AS();
CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
if constexpr(sizeof(value_type) <= 1 )
return _get_as<Tx>(); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
else
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
// Mutable access to the i-th Tx element (runtime index): reinterprets `data` as
// thread_buffer<Tx, vx> and returns get(i) by reference.
template<typename Tx>
CK_TILE_HOST_DEVICE auto & get_as(index_t i) {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(i);}
// Const counterpart of get_as<Tx>(index_t): reinterprets `data` as
// const thread_buffer<Tx, vx> and returns get(i) by const reference.
template<typename Tx>
CK_TILE_HOST_DEVICE const auto & get_as(index_t i) const {TB_COMMON_AS();
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(i);}
// Mutable access to the I-th Tx element (compile-time index): reinterprets `data` as
// thread_buffer<Tx, vx> and returns get(number<I>{}) by reference.
template<typename Tx, index_t I>
CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
// Read-only access to the I-th Tx element (compile-time index), returned by value.
// When value_type is at most one byte wide the element is assembled by
// _get_as<Tx>(number<I>{}); otherwise `data` is reinterpret_cast to
// const thread_buffer<Tx, vx>& and get(number<I>{}) is used.
template<typename Tx, index_t I>
CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
if constexpr(sizeof(value_type) <= 1 )
return _get_as<Tx>(number<I>{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
else
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
// Assigns x to element i (runtime index) of this buffer's storage viewed as vx elements of Tx.
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
// Assigns x to element I (compile-time index) of this buffer's storage viewed as vx elements of Tx.
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef TB_COMMON_AS
};
// clang-format on

View File

@@ -38,6 +38,16 @@ struct ext_vector
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
// Partial specialization for an element type that is itself an ext_vector_type of Vs_
// elements of V_: the result is flattened into a single ext_vector_type with Vs_ * N_
// elements of the type given by native_t<remove_cvref_t<V_>>. Class types are rejected.
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};
} // namespace impl
template <typename T, index_t N>
@@ -60,6 +70,10 @@ struct vector_traits<T __attribute__((ext_vector_type(N)))>
static constexpr index_t vector_size = N;
};
// std::is_same over the scalar_type that vector_traits reports for X and for Y,
// after cv/reference qualifiers have been stripped from both.
template <typename X, typename Y>
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
// below are some predefined ext_vector_type types
// attention! two vector types could be exactly the same type
// fp64

View File

@@ -119,8 +119,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVec>(
number<in_offset / vec_length_in>{});
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
});
// transpose

View File

@@ -39,9 +39,9 @@ struct WarpGemmImpl
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>(I0);
const auto b_vec = b.get_thread_buffer().template get_as<BVec>(I0);
auto c_vec = c.get_thread_buffer().template get_as<CVec>(I0);
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
@@ -59,8 +59,8 @@ struct WarpGemmImpl
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>(I0);
const auto b_vec = b.get_thread_buffer().template get_as<BVec>(I0);
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
// c_vec = a_vec * b_vec
auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);