mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Fix continous dim selection in contraction (#1336)
* Fix continous dim selection in contraction * Fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
// for sanity check of vector memory access
|
||||
for(index_t i = 0; i < NumATensor; ++i)
|
||||
{
|
||||
as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
|
||||
as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
|
||||
as_max_read_elems_[i] =
|
||||
tie(as_continous_dim_[i], as_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < NumBTensor; ++i)
|
||||
{
|
||||
bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
|
||||
bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
|
||||
bs_max_read_elems_[i] =
|
||||
tie(bs_continous_dim_[i], bs_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
|
||||
ds_max_read_elems_[i] =
|
||||
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
|
||||
}
|
||||
|
||||
e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
|
||||
e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
|
||||
tie(e_continous_dim_, e_max_write_elems_) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
|
||||
}
|
||||
|
||||
// pointers
|
||||
@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
|
||||
// in the memory or not.
|
||||
std::array<bool, NumATensor> as_mz_consecutive_;
|
||||
std::array<bool, NumATensor> as_kz_consecutive_;
|
||||
std::array<bool, NumBTensor> bs_nz_consecutive_;
|
||||
std::array<bool, NumBTensor> bs_kz_consecutive_;
|
||||
std::array<bool, NumDTensor> ds_nz_consecutive_;
|
||||
bool e_nz_consecutive_;
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
|
||||
std::array<index_t, NumATensor> as_continous_dim_;
|
||||
std::array<index_t, NumATensor> bs_continous_dim_;
|
||||
std::array<index_t, NumBTensor> ds_continous_dim_;
|
||||
index_t e_continous_dim_;
|
||||
|
||||
std::array<index_t, NumATensor> as_max_read_elems_;
|
||||
std::array<index_t, NumBTensor> bs_max_read_elems_;
|
||||
@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
const bool valid_a_vector_size =
|
||||
arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_a_access_dim_m =
|
||||
ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
|
||||
ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
|
||||
const bool valid_a_access_dim_k =
|
||||
ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
|
||||
ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
|
||||
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
|
||||
if(!((valid_a_vector_size && valid_a_access_dim) ||
|
||||
ABlockTransferSrcScalarPerVector == 1))
|
||||
@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
const bool valid_b_vector_size =
|
||||
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_b_access_dim_n =
|
||||
BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
|
||||
BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
|
||||
const bool valid_b_access_dim_k =
|
||||
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
|
||||
BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
|
||||
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
|
||||
if(!((valid_b_vector_size && valid_b_access_dim) ||
|
||||
BBlockTransferSrcScalarPerVector == 1))
|
||||
@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
const bool valid_d_vector_size =
|
||||
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector read of Ds is always on N dimension.
|
||||
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
|
||||
const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
|
||||
if(!((valid_d_vector_size && valid_d_access_dim) ||
|
||||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
|
||||
{
|
||||
@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
const bool valid_e_vector_size =
|
||||
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector write of E is always on N dimension.
|
||||
const bool valid_e_access_dim = arg.e_nz_consecutive_;
|
||||
const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
|
||||
if(!((valid_e_vector_size && valid_e_access_dim) ||
|
||||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
|
||||
{
|
||||
|
||||
@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// for sanity check of vector memory access
|
||||
a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
|
||||
a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
|
||||
a_max_read_elems_ =
|
||||
tie(a_continous_dim_, a_max_read_elems_) =
|
||||
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
|
||||
b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
|
||||
b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
|
||||
b_max_read_elems_ =
|
||||
tie(b_continous_dim_, b_max_read_elems_) =
|
||||
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
|
||||
ds_max_read_elems_[i] =
|
||||
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
|
||||
}
|
||||
|
||||
e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
|
||||
e_max_write_elems_ =
|
||||
tie(e_continous_dim_, e_max_write_elems_) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
}
|
||||
|
||||
@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
|
||||
// in the memory or not.
|
||||
bool a_mz_consecutive_;
|
||||
bool a_kz_consecutive_;
|
||||
bool b_nz_consecutive_;
|
||||
bool b_kz_consecutive_;
|
||||
std::array<bool, NumDTensor> ds_nz_consecutive_;
|
||||
bool e_nz_consecutive_;
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
|
||||
index_t a_continous_dim_;
|
||||
index_t b_continous_dim_;
|
||||
std::array<index_t, NumDTensor> ds_continous_dim_;
|
||||
index_t e_continous_dim_;
|
||||
|
||||
index_t a_max_read_elems_;
|
||||
index_t b_max_read_elems_;
|
||||
@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
const bool valid_a_vector_size =
|
||||
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
|
||||
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
|
||||
const bool valid_a_access_dim_m =
|
||||
ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
|
||||
const bool valid_a_access_dim_k =
|
||||
ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
|
||||
const bool valid_a_access_dim =
|
||||
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
|
||||
if(!(valid_a_vector_size && valid_a_access_dim))
|
||||
@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
const bool valid_b_vector_size =
|
||||
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
|
||||
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
|
||||
const bool valid_b_access_dim_n =
|
||||
BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
|
||||
const bool valid_b_access_dim_k =
|
||||
BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
|
||||
const bool valid_b_access_dim =
|
||||
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
|
||||
if(!(valid_b_vector_size && valid_b_access_dim))
|
||||
@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector read of Ds is always on N dimension.
|
||||
const bool valid_d_access_dim =
|
||||
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
if(!(valid_d_vector_size && valid_d_access_dim))
|
||||
{
|
||||
valid_ds_access = false;
|
||||
@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector write of E is always on N dimension.
|
||||
const bool valid_e_access_dim =
|
||||
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
if(!(valid_e_vector_size && valid_e_access_dim))
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
|
||||
}
|
||||
|
||||
// Determine the beginning and end idx of the group representing the FCD.
|
||||
index_t begin_idx, end_idx;
|
||||
if(strides[NumDim1 - 1] == 1)
|
||||
index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
|
||||
if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
|
||||
{
|
||||
begin_idx = 0;
|
||||
end_idx = NumDim1 - 1;
|
||||
// MZ or KZ are ones
|
||||
bool dims1_are_ones = true;
|
||||
for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
|
||||
{
|
||||
if(lengths[dim_idx] != 1)
|
||||
{
|
||||
dims1_are_ones = false;
|
||||
}
|
||||
}
|
||||
|
||||
if(dims1_are_ones)
|
||||
{
|
||||
begin_idx = NumDim1;
|
||||
end_idx = NumDim1 + NumDim2 - 1;
|
||||
continous_dim = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
begin_idx = 0;
|
||||
end_idx = NumDim1 - 1;
|
||||
continous_dim = 0;
|
||||
}
|
||||
}
|
||||
else if(strides[NumDim1 - 1] == 1)
|
||||
{
|
||||
begin_idx = 0;
|
||||
end_idx = NumDim1 - 1;
|
||||
continous_dim = 0;
|
||||
}
|
||||
else if(strides[NumDim1 + NumDim2 - 1] == 1)
|
||||
{
|
||||
begin_idx = NumDim1;
|
||||
end_idx = NumDim1 + NumDim2 - 1;
|
||||
begin_idx = NumDim1;
|
||||
end_idx = NumDim1 + NumDim2 - 1;
|
||||
continous_dim = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// The dimension consecutive in memory is not the last dimension of any group, so only
|
||||
// one element can be read/written at once.
|
||||
return 1;
|
||||
consecutive_stride = 1;
|
||||
continous_dim = 0;
|
||||
return make_tuple(continous_dim, consecutive_stride);
|
||||
}
|
||||
|
||||
index_t consecutive_stride = 1;
|
||||
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
|
||||
{
|
||||
if(strides[dim_idx] == consecutive_stride)
|
||||
@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
|
||||
}
|
||||
}
|
||||
const index_t max_subsequent_elems = consecutive_stride;
|
||||
return max_subsequent_elems;
|
||||
return make_tuple(continous_dim, max_subsequent_elems);
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user