From ac08f8a3a1e098dbd65aff860afcfee413c5ea0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 23 Apr 2024 22:59:39 +0200 Subject: [PATCH] Fix contraction IsSupported checks (#1257) [ROCm/composable_kernel commit: b1f8ae379bbed73d0b92482083c1e54abebd14a0] --- .../device_contraction_multiple_d_xdl_cshuffle.hpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 1f65afed3d..4cc60f2836 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; - const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; + const bool valid_a_access_dim = + valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1; if(!(valid_a_vector_size && valid_a_access_dim)) { return false; @@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; - const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; + const bool valid_b_access_dim = + valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1; if(!(valid_b_vector_size && valid_b_access_dim)) { return false; @@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_d_vector_size = arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. - const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + const bool valid_d_access_dim = + arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_d_vector_size && valid_d_access_dim)) { valid_ds_access = false; @@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_e_vector_size = arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. - const bool valid_e_access_dim = arg.e_nz_consecutive_; + const bool valid_e_access_dim = + arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_e_vector_size && valid_e_access_dim)) { return false;