diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 33e03a85e2..dae16612cc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // for sanity check of vector memory access for(index_t i = 0; i < NumATensor; ++i) { - as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1; - as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1; - as_max_read_elems_[i] = + tie(as_continous_dim_[i], as_max_read_elems_[i]) = CalculateMaxRead(a_ms_ks_lengths[i], a_ms_ks_strides[i]); } for(index_t i = 0; i < NumBTensor; ++i) { - bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1; - bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1; - bs_max_read_elems_[i] = + tie(bs_continous_dim_[i], bs_max_read_elems_[i]) = CalculateMaxRead(b_ns_ks_lengths[i], b_ns_ks_strides[i]); } for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; - ds_max_read_elems_[i] = + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = CalculateMaxRead(d_ms_ns_lengths[i], d_ms_ns_strides[i]); } - e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1; - e_max_write_elems_ = CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); + tie(e_continous_dim_, e_max_write_elems_) = + CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); } // pointers @@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle 
BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Describe whether the last part of a given dimension of A/B/D/E is consecutive - // in the memory or not. - std::array as_mz_consecutive_; - std::array as_kz_consecutive_; - std::array bs_nz_consecutive_; - std::array bs_kz_consecutive_; - std::array ds_nz_consecutive_; - bool e_nz_consecutive_; + // Describe whether the last part of a given dimension of A/B/D/E is the contiguous dim. + std::array as_continous_dim_; + std::array bs_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; std::array as_max_read_elems_; std::array bs_max_read_elems_; @@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_a_vector_size = arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; const bool valid_a_access_dim_m = - ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i]; + ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0; const bool valid_a_access_dim_k = - ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i]; + ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; if(!((valid_a_vector_size && valid_a_access_dim) || ABlockTransferSrcScalarPerVector == 1)) @@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_b_vector_size = arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; const bool valid_b_access_dim_n = - BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i]; + BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0; const bool valid_b_access_dim_k = - BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i]; + BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; if(!((valid_b_vector_size && valid_b_access_dim) || BBlockTransferSrcScalarPerVector == 1)) @@ 
-699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_d_vector_size = arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. - const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1; if(!((valid_d_vector_size && valid_d_access_dim) || CDEBlockTransferScalarPerVector_NPerBlock == 1)) { @@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_e_vector_size = arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. - const bool valid_e_access_dim = arg.e_nz_consecutive_; + const bool valid_e_access_dim = arg.e_continous_dim_ == 1; if(!((valid_e_vector_size && valid_e_access_dim) || CDEBlockTransferScalarPerVector_NPerBlock == 1)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 9d5b74be6c..f1bc6a2261 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } // for sanity check of vector memory access - a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1; - a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1; - a_max_read_elems_ = + tie(a_continous_dim_, a_max_read_elems_) = CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); - b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1; - b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1; - b_max_read_elems_ = + tie(b_continous_dim_, b_max_read_elems_) = CalculateMaxRead(b_ns_ks_lengths, b_ns_ks_strides); for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_consecutive_[i] = 
ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; - ds_max_read_elems_[i] = + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = CalculateMaxRead(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); } - e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1; - e_max_write_elems_ = + tie(e_continous_dim_, e_max_write_elems_) = CalculateMaxRead(e_ms_ns_lengths, e_ms_ns_strides); } @@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Describe whether the last part of a given dimension of A/B/D/E is consecutive - // in the memory or not. - bool a_mz_consecutive_; - bool a_kz_consecutive_; - bool b_nz_consecutive_; - bool b_kz_consecutive_; - std::array ds_nz_consecutive_; - bool e_nz_consecutive_; + // Describe whether the last part of a given dimension of A/B/D/E is the contiguous dim. + index_t a_continous_dim_; + index_t b_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; index_t a_max_read_elems_; index_t b_max_read_elems_; @@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_a_vector_size = arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; - const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; - const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1; if(!(valid_a_vector_size && valid_a_access_dim)) @@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_b_vector_size = arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; - const bool valid_b_access_dim_n = 
BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; - const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1; if(!(valid_b_vector_size && valid_b_access_dim)) @@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. const bool valid_d_access_dim = - arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1; + arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_d_vector_size && valid_d_access_dim)) { valid_ds_access = false; @@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. const bool valid_e_access_dim = - arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1; + arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_e_vector_size && valid_e_access_dim)) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp index 838305f187..1b0db73fdd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector= begin_idx; --dim_idx) { if(strides[dim_idx] == consecutive_stride) @@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vectortemplate Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); + + // special cases + this->template Run<2>({{1, 1}, {16, 8}, {8, 16}}); + this->template Run<2>({{8, 16}, {16, 8}, {1, 1}}); + this->template Run<2>({{8, 16}, {1, 1}, {8, 16}}); + this->template Run<2>({{1, 1}, {1, 1}, {1, 1}}); }