mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
fix scratch in fp8 kernel
This commit is contained in:
@@ -59,6 +59,64 @@ struct thread_buffer {
|
||||
// Compile-time-indexed element access: forwards the index I to get() and returns the result by reference.
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
// Const overload of at(number<I>): forwards the index I to get() and returns a const reference.
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
// Returns, by value, a copy of this buffer expressed as thread_buffer<X_, N / kSPerX>,
// where kSPerX is vector_traits<X>::vector_size (the number of scalars in one X).
// Only participates in overload resolution when X_ and value_type share a scalar type.
// The copy is made scalar by scalar through a local union instead of a reinterpret_cast
// of the member storage.
template <typename X_,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
// The N scalars must split evenly into X-sized groups.
static_assert(N % kSPerX == 0);
|
||||
|
||||
// Two overlapping views of the same local storage: the typed result and a flat scalar array.
union {
|
||||
thread_buffer<X_, N / kSPerX> data {};
|
||||
// tuple_array<value_type, kSPerX> sub_data;
|
||||
value_type sub_data[N];
|
||||
} vx;
|
||||
// Unrolled copy of all N scalars. The unqualified `data` on the right-hand side is the
// enclosing buffer's own member (the union's member is only reachable as vx.data).
static_for<0, N, 1>{}(
|
||||
[&](auto j) { vx.sub_data[j] = data[j]; });
|
||||
// Written through sub_data, read back through the typed member.
return vx.data;
|
||||
}
|
||||
|
||||
// Returns, by value, the Is-th X_-sized element of this buffer.
// kSPerX scalars are read with operator[] starting at index
// is * (sizeof(X_) / sizeof(value_type)), stored into a local union, and the
// union's X_ member is returned.
// Only participates in overload resolution when X_ and value_type share a scalar type.
// NOTE(review): the loop count comes from vector_traits<X>::vector_size while the
// start offset uses sizeof(X_) / sizeof(value_type); these are presumably equal — confirm.
template <typename X_,
|
||||
index_t Is,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
|
||||
// Two overlapping views of the same local storage: one X_ value and its kSPerX scalars.
union {
|
||||
X_ data {};
|
||||
tuple_array<value_type, kSPerX> sub_data;
|
||||
} vx;
|
||||
// Unrolled gather of the kSPerX scalars that make up element `is`.
static_for<0, kSPerX, 1>{}(
|
||||
[&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
|
||||
// Written through sub_data, read back through the X_ member.
return vx.data;
|
||||
}
|
||||
|
||||
// Disabled code (`#if 0`): a union-based setter, never compiled.
// As written it would initialize a local union from x and copy its kSPerX scalars
// into this buffer through operator(), starting at index
// is * (sizeof(X_) / sizeof(value_type)).
// NOTE(review): sub_data is indexed with [] here, whereas the enabled union-based
// getter indexes the same tuple_array type with (); check which form is valid
// before re-enabling.
#if 0
|
||||
template <typename X_,
|
||||
index_t Is,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
|
||||
union {
|
||||
X_ data;
|
||||
tuple_array<value_type, kSPerX> sub_data;
|
||||
} vx {x};
|
||||
|
||||
static_for<0, kSPerX, 1>{}(
|
||||
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
// Shared prologue for the get_as / set_as accessors. Requires a type named `Tx` to be
// in scope at the expansion site. It asserts that the buffer's total byte size
// (sizeof(value_type) * N) is a whole multiple of sizeof(Tx), then declares the local
// constant `vx`, the number of Tx elements that fit in the buffer.
// The expansion has no trailing semicolon, so call sites write `TB_COMMON_AS();`.
#define TB_COMMON_AS() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
@@ -67,19 +125,26 @@ struct thread_buffer {
|
||||
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
|
||||
// Const access to the whole buffer viewed as thread_buffer<Tx, vx>.
// NOTE(review): two signature lines appear back to back below (one returning
// `const auto &`, one returning `constexpr auto`). This looks like the before/after
// pair of a diff rendered without +/- markers — confirm against the real header.
// With the `constexpr auto` signature both branches return by value:
//   - sizeof(value_type) <= 1: the result is built by the union-based _get_as<Tx>();
//   - otherwise: `data` is reinterpret_cast to thread_buffer<Tx, vx> and copied out.
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE const auto & get_as() const {TB_COMMON_AS();
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
|
||||
if constexpr(sizeof(value_type) <= 1 )
|
||||
return _get_as<Tx>(); // TODO: the current compiler needs a union to read 8-bit data back; this should be fixed in the future
|
||||
else
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
|
||||
// Returns a mutable reference to element i (runtime index) of this buffer viewed as
// thread_buffer<Tx, vx>; the view is a reinterpret_cast of `data`.
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE auto & get_as(index_t i) {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(i);}
|
||||
// Returns a const reference to element i (runtime index) of this buffer viewed as
// thread_buffer<Tx, vx>; the view is a reinterpret_cast of `data`.
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE const auto & get_as(index_t i) const {TB_COMMON_AS();
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(i);}
|
||||
// Returns a mutable reference to element I (compile-time index) of this buffer viewed
// as thread_buffer<Tx, vx>; the view is a reinterpret_cast of `data`.
template<typename Tx, index_t I>
|
||||
CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
|
||||
// Returns, by value, element I (compile-time index) of this buffer viewed as Tx.
//   - sizeof(value_type) <= 1: the element is assembled by the union-based
//     _get_as<Tx>(number<I>{});
//   - otherwise: `data` is reinterpret_cast to thread_buffer<Tx, vx> and element I is
//     copied out (the `auto` return type drops the reference).
template<typename Tx, index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
|
||||
if constexpr(sizeof(value_type) <= 1 )
|
||||
return _get_as<Tx>(number<I>{}); // TODO: the current compiler needs a union to read 8-bit data back; this should be fixed in the future
|
||||
else
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
|
||||
|
||||
// Stores x as element i (runtime index) of this buffer viewed as a sequence of vx Tx values.
// NOTE(review): two body lines follow, one casting `data` to array<Tx, vx> and one to
// thread_buffer<Tx, vx>. They appear to be the before/after lines of a diff rendered
// without +/- markers — confirm against the real header.
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
|
||||
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
|
||||
// Stores x as element I (compile-time index) of this buffer viewed as a sequence of vx Tx values.
// NOTE(review): two body lines follow, one casting `data` to array<Tx, vx> and one to
// thread_buffer<Tx, vx>. They appear to be the before/after lines of a diff rendered
// without +/- markers — confirm against the real header.
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
|
||||
#undef TB_COMMON_AS
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
@@ -38,6 +38,16 @@ struct ext_vector
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
|
||||
};
|
||||
|
||||
// Partial specialization of ext_vector for an element type that is itself an
// ext_vector_type of Vs_ elements of V_. The nesting is flattened: the resulting
// `type` is a single ext_vector_type of N = Vs_ * N_ scalars, where the scalar is
// native_t<remove_cvref_t<V_>>::type. Class types are rejected by the static_assert.
template <typename V_, index_t Vs_, index_t N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
|
||||
{
|
||||
// Total scalar count after flattening.
static constexpr index_t N = Vs_ * N_;
|
||||
using value_type = typename native_t<remove_cvref_t<V_>>::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
template <typename T, index_t N>
|
||||
@@ -60,6 +70,10 @@ struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
// Type trait: std::is_same over the vector_traits scalar_type of X and Y, after
// stripping cv-qualifiers and references from both. Its ::value is true when the two
// types share a scalar type, whatever their vector widths.
template <typename X, typename Y>
|
||||
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
|
||||
|
||||
// below are some pre-defines of ext_vector_type
|
||||
// attention! two vector types could be exactly the same type
|
||||
// fp64
|
||||
|
||||
@@ -119,8 +119,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer().template get_as<InVec>(
|
||||
number<in_offset / vec_length_in>{});
|
||||
in_tensor.get_thread_buffer()
|
||||
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
|
||||
});
|
||||
|
||||
// transpose
|
||||
|
||||
@@ -39,9 +39,9 @@ struct WarpGemmImpl
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>(I0);
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>(I0);
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>(I0);
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
|
||||
@@ -59,8 +59,8 @@ struct WarpGemmImpl
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>(I0);
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>(I0);
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
|
||||
|
||||
Reference in New Issue
Block a user