From 6db9cf9f68f28aed0bde287c17280ac3dc8dd8c9 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Mon, 26 Jan 2026 17:12:11 +0000 Subject: [PATCH] Fix --- include/ck_tile/core/tensor/tensor_view.hpp | 57 +++++++++++---------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 053e1c8ea7..24cb397643 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -40,12 +40,13 @@ template ; - using DataType = remove_cvref_t; + using DataType = typename buffer_view::type; + using DataType_ = remove_cvref_t; using TensorDesc = remove_cvref_t; using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); static constexpr auto DstInMemOp = DstInMemOp_; - static constexpr index_t PackedSize = ck_tile::numeric_traits::PackedSize; + static constexpr index_t PackedSize = ck_tile::numeric_traits::PackedSize; template using vector_scalar_t = typename vector_traits>::scalar_type; @@ -77,7 +78,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr remove_cvref_t get_vectorized_elements(const TensorCoord& coord, @@ -95,7 +96,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr remove_cvref_t get_vectorized_elements(const TensorCoord& coord, @@ -116,7 +117,7 @@ struct tensor_view bool pre_nop = false, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, const TensorCoord& coord, @@ -137,7 +138,7 @@ struct tensor_view bool pre_nop = false, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, const TensorCoord& coord, @@ -158,9 +159,9 @@ struct tensor_view bool oob_conditional_check = true, index_t IMM = 0, typename = std::enable_if_t< - std::is_same_v>, vector_scalar_t>>> + std::is_same_v>, vector_scalar_t>>> CK_TILE_HOST_DEVICE constexpr void - async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType* smem, + async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem, index_t offset, index_t wave_offset, number = {}, @@ -177,9 +178,9 @@ struct tensor_view template >, vector_scalar_t>>> + std::is_same_v>, vector_scalar_t>>> CK_TILE_HOST_DEVICE constexpr void - async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType* smem, + async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem, const TensorCoord& coord, index_t linear_offset, bool_constant = {}) const @@ -197,9 +198,9 @@ struct tensor_view bool oob_conditional_check = true, typename = std::enable_if_t< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>>> + typename vector_traits::scalar_type>>> CK_TILE_HOST_DEVICE constexpr void - async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType* smem, + async_get_vectorized_elements(CK_TILE_LDS_ADDR DataType_* smem, const TensorCoord& coord, index_t linear_offset, bool is_valid_element, @@ -217,10 +218,10 @@ struct tensor_view bool pre_nop = false, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void - async_get_vectorized_elements_raw(DataType* smem, + async_get_vectorized_elements_raw(DataType_* smem, const TensorCoord& coord, index_t linear_offset, bool_constant = {}) const @@ -237,10 +238,10 @@ struct tensor_view bool pre_nop = false, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void - async_get_vectorized_elements_raw(DataType* smem, + async_get_vectorized_elements_raw(DataType_* smem, const TensorCoord& coord, index_t coord_extra_offset, index_t linear_offset, @@ -258,10 +259,10 @@ struct tensor_view bool pre_nop = false, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void - async_get_vectorized_elements_raw(DataType* smem, + async_get_vectorized_elements_raw(DataType_* smem, const TensorCoord& coord, index_t linear_offset, bool is_valid_element, @@ -277,7 +278,7 @@ struct tensor_view template >::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr remove_cvref_t get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const @@ -291,7 +292,7 @@ struct tensor_view template >::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr remove_cvref_t get_transpose_vectorized_elements(const TensorCoord& coord, @@ -307,7 +308,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(const TensorCoord& coord, @@ -326,7 +327,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(const TensorCoord& coord, @@ -343,7 +344,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(const TensorCoord& coord, @@ -362,7 +363,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(const TensorCoord& coord, @@ -381,7 +382,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(const TensorCoord& coord, @@ -400,7 +401,7 @@ struct tensor_view bool oob_conditional_check = true, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(const TensorCoord& coord, @@ -420,7 +421,7 @@ struct tensor_view bool pre_nop = false, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements_raw(const TensorCoord& coord, @@ -441,7 +442,7 @@ struct tensor_view bool pre_nop = false, typename std::enable_if< std::is_same_v>::scalar_type, - typename vector_traits::scalar_type>, + typename vector_traits::scalar_type>, bool>::type = false> CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements_raw(const TensorCoord& coord,