mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Fix contraction IsSupported checks (#1257)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
|
||||
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
|
||||
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
|
||||
const bool valid_a_access_dim =
|
||||
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
|
||||
if(!(valid_a_vector_size && valid_a_access_dim))
|
||||
{
|
||||
return false;
|
||||
@@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
|
||||
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
|
||||
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
|
||||
const bool valid_b_access_dim =
|
||||
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
|
||||
if(!(valid_b_vector_size && valid_b_access_dim))
|
||||
{
|
||||
return false;
|
||||
@@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
const bool valid_d_vector_size =
|
||||
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector read of Ds is always on N dimension.
|
||||
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
|
||||
const bool valid_d_access_dim =
|
||||
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
if(!(valid_d_vector_size && valid_d_access_dim))
|
||||
{
|
||||
valid_ds_access = false;
|
||||
@@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
const bool valid_e_vector_size =
|
||||
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector write of E is always on N dimension.
|
||||
const bool valid_e_access_dim = arg.e_nz_consecutive_;
|
||||
const bool valid_e_access_dim =
|
||||
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
if(!(valid_e_vector_size && valid_e_access_dim))
|
||||
{
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user