mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Fix
This commit is contained in:
@@ -40,12 +40,13 @@ template <typename BufferView_,
|
||||
struct tensor_view
|
||||
{
|
||||
using buffer_view = remove_reference_t<BufferView_>;
|
||||
using DataType = remove_cvref_t<typename buffer_view::type>;
|
||||
using DataType = typename buffer_view::type;
|
||||
using DataType_ = remove_cvref_t<DataType>;
|
||||
using TensorDesc = remove_cvref_t<TensorDesc_>;
|
||||
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
|
||||
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
|
||||
static constexpr auto DstInMemOp = DstInMemOp_;
|
||||
static constexpr index_t PackedSize = ck_tile::numeric_traits<DataType>::PackedSize;
|
||||
static constexpr index_t PackedSize = ck_tile::numeric_traits<DataType_>::PackedSize;
|
||||
|
||||
template <typename T>
|
||||
using vector_scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
@@ -77,7 +78,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
@@ -95,7 +96,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
@@ -116,7 +117,7 @@ struct tensor_view
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
@@ -137,7 +138,7 @@ struct tensor_view
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
@@ -158,9 +159,9 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
index_t IMM = 0,
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<vector_scalar_t<remove_cvref_t<X>>, vector_scalar_t<DataType>>>>
|
||||
std::is_same_v<vector_scalar_t<remove_cvref_t<X>>, vector_scalar_t<DataType_>>>>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType* smem,
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem,
|
||||
index_t offset,
|
||||
index_t wave_offset,
|
||||
number<IMM> = {},
|
||||
@@ -177,9 +178,9 @@ struct tensor_view
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<vector_scalar_t<remove_cvref_t<X>>, vector_scalar_t<DataType>>>>
|
||||
std::is_same_v<vector_scalar_t<remove_cvref_t<X>>, vector_scalar_t<DataType_>>>>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType* smem,
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
@@ -197,9 +198,9 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>>>
|
||||
typename vector_traits<DataType_>::scalar_type>>>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType* smem,
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
@@ -217,10 +218,10 @@ struct tensor_view
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(DataType* smem,
|
||||
async_get_vectorized_elements_raw(DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
@@ -237,10 +238,10 @@ struct tensor_view
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(DataType* smem,
|
||||
async_get_vectorized_elements_raw(DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t coord_extra_offset,
|
||||
index_t linear_offset,
|
||||
@@ -258,10 +259,10 @@ struct tensor_view
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(DataType* smem,
|
||||
async_get_vectorized_elements_raw(DataType_* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
@@ -277,7 +278,7 @@ struct tensor_view
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
|
||||
@@ -291,7 +292,7 @@ struct tensor_view
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord,
|
||||
@@ -307,7 +308,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
@@ -326,7 +327,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
@@ -343,7 +344,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
@@ -362,7 +363,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
@@ -381,7 +382,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
@@ -400,7 +401,7 @@ struct tensor_view
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
@@ -420,7 +421,7 @@ struct tensor_view
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
@@ -441,7 +442,7 @@ struct tensor_view
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<DataType>::scalar_type>,
|
||||
typename vector_traits<DataType_>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
|
||||
Reference in New Issue
Block a user