Add blockwise gemm to ck wrapper (#1139)

* Add blockwise gemm to ck wrapper

* Add blockwise gemm traits

* Disable test_gemm for non xdl devices

* Fixes

* Add c layout descritpions
This commit is contained in:
Bartłomiej Kocot
2024-01-31 21:24:40 +01:00
committed by GitHub
parent 6651a124cc
commit f3b6c23ac5
12 changed files with 1064 additions and 116 deletions

View File

@@ -10,8 +10,8 @@
namespace ck {
namespace wrapper {
namespace detail {
namespace {
namespace detail {
/**
* \brief Check if Tuple contains Slice object
*
@@ -187,8 +187,8 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace
} // namespace detail
} // namespace
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
@@ -209,7 +209,10 @@ struct Tensor
public:
using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using TensorElementType = ElementType; // DataType
using TensorElementType = std::conditional_t<
is_scalar_type<ElementType>::value,
ElementType,
typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
@@ -280,7 +283,7 @@ struct Tensor
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
__host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
{
if constexpr(IsDynamicBuffer)
{
@@ -301,13 +304,13 @@ struct Tensor
}
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
__host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(Idxs... idxs) const
__host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
{
return this->operator[](make_tuple(idxs...));
}
@@ -319,7 +322,7 @@ struct Tensor
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
__host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
{
if constexpr(IsDynamicBuffer)
{
@@ -340,13 +343,13 @@ struct Tensor
}
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
__host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ ElementType& operator()(Idxs... idxs)
__host__ __device__ TensorElementType& operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
@@ -366,7 +369,7 @@ struct Tensor
*
* \return Pointer.
*/
__host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
__host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }
__host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
__host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
@@ -395,10 +398,18 @@ struct Tensor
ElementType,
ElementSpaceSize,
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType = StaticBuffer<BufferAddressSpace,
ElementType,
size(Shape{}),
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType = std::conditional_t<
is_scalar_type<ElementType>::value,
StaticBuffer<BufferAddressSpace,
ElementType,
size(Shape{}),
true /*InvalidElementUseNumericalZeroValue*/>,
StaticBufferTupleOfVector<BufferAddressSpace,
TensorElementType,
size(Shape{}) /
scalar_type<std::remove_const_t<ElementType>>::vector_size,
scalar_type<std::remove_const_t<ElementType>>::vector_size,
true /*InvalidElementUseNumericalZeroValue*/>>;
// If register use static buffer, else use dynamic buffer
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;