mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add blockwise gemm to ck wrapper (#1139)
* Add blockwise gemm to ck wrapper * Add blockwise gemm traits * Disable test_gemm for non xdl devices * Fixes * Add c layout descriptions
This commit is contained in:
@@ -10,8 +10,8 @@
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
namespace detail {
|
||||
namespace {
|
||||
namespace detail {
|
||||
/**
|
||||
* \brief Check if Tuple contains Slice object
|
||||
*
|
||||
@@ -187,8 +187,8 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
|
||||
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
|
||||
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* \brief Tensor wrapper that performs static and dynamic buffer logic.
|
||||
@@ -209,7 +209,10 @@ struct Tensor
|
||||
public:
|
||||
using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
|
||||
Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
|
||||
using TensorElementType = ElementType; // DataType
|
||||
using TensorElementType = std::conditional_t<
|
||||
is_scalar_type<ElementType>::value,
|
||||
ElementType,
|
||||
typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType
|
||||
|
||||
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
|
||||
static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
|
||||
@@ -280,7 +283,7 @@ struct Tensor
|
||||
* \return Requested value.
|
||||
*/
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
|
||||
__host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
|
||||
{
|
||||
if constexpr(IsDynamicBuffer)
|
||||
{
|
||||
@@ -301,13 +304,13 @@ struct Tensor
|
||||
}
|
||||
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
|
||||
__host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
|
||||
{
|
||||
return this->operator[](idx);
|
||||
}
|
||||
|
||||
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator()(Idxs... idxs) const
|
||||
__host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
|
||||
{
|
||||
return this->operator[](make_tuple(idxs...));
|
||||
}
|
||||
@@ -319,7 +322,7 @@ struct Tensor
|
||||
* \return Requested value.
|
||||
*/
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
|
||||
__host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
|
||||
{
|
||||
if constexpr(IsDynamicBuffer)
|
||||
{
|
||||
@@ -340,13 +343,13 @@ struct Tensor
|
||||
}
|
||||
|
||||
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
|
||||
__host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
|
||||
{
|
||||
return this->operator[](idx);
|
||||
}
|
||||
|
||||
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator()(Idxs... idxs)
|
||||
__host__ __device__ TensorElementType& operator()(Idxs... idxs)
|
||||
{
|
||||
return this->operator[](make_tuple(idxs...));
|
||||
}
|
||||
@@ -366,7 +369,7 @@ struct Tensor
|
||||
*
|
||||
* \return Pointer.
|
||||
*/
|
||||
__host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
|
||||
__host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }
|
||||
|
||||
__host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
|
||||
__host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
|
||||
@@ -395,10 +398,18 @@ struct Tensor
|
||||
ElementType,
|
||||
ElementSpaceSize,
|
||||
true /*InvalidElementUseNumericalZeroValue*/>;
|
||||
using StaticBufferType = StaticBuffer<BufferAddressSpace,
|
||||
ElementType,
|
||||
size(Shape{}),
|
||||
true /*InvalidElementUseNumericalZeroValue*/>;
|
||||
using StaticBufferType = std::conditional_t<
|
||||
is_scalar_type<ElementType>::value,
|
||||
StaticBuffer<BufferAddressSpace,
|
||||
ElementType,
|
||||
size(Shape{}),
|
||||
true /*InvalidElementUseNumericalZeroValue*/>,
|
||||
StaticBufferTupleOfVector<BufferAddressSpace,
|
||||
TensorElementType,
|
||||
size(Shape{}) /
|
||||
scalar_type<std::remove_const_t<ElementType>>::vector_size,
|
||||
scalar_type<std::remove_const_t<ElementType>>::vector_size,
|
||||
true /*InvalidElementUseNumericalZeroValue*/>>;
|
||||
// If register use static buffer, else use dynamic buffer
|
||||
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user