mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Fix IsSupported check in the contraction op (#1066)
Current implementation of IsSupported method in contraction ops does not cover a lot of possible cases in which ScalarPerVector cannot really be used to read A, B or D, or write E. This PR extends both the regular and multiABD contraction ops with improved checks and also adds new instances with smaller values of ScalarPerVector to support instances that are not supported by other instances.
This commit is contained in:
committed by
GitHub
parent
f199035b74
commit
89ee47460b
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
// for sanity check of vector memory access
|
||||
for(index_t i = 0; i < NumATensor; ++i)
|
||||
{
|
||||
a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1];
|
||||
a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1];
|
||||
as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
|
||||
as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
|
||||
as_max_read_elems_[i] =
|
||||
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < NumBTensor; ++i)
|
||||
{
|
||||
b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1];
|
||||
b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1];
|
||||
bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
|
||||
bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
|
||||
bs_max_read_elems_[i] =
|
||||
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1];
|
||||
ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
|
||||
ds_max_read_elems_[i] =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
|
||||
}
|
||||
|
||||
e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
|
||||
e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
|
||||
e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
|
||||
}
|
||||
|
||||
// pointers
|
||||
@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Strides for the last M/N/K dimensions of A/B/Ds/E
|
||||
// for sanity check of vector load/store
|
||||
std::array<index_t, NumATensor> a_mz_stride_;
|
||||
std::array<index_t, NumATensor> a_kz_stride_;
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
|
||||
// in the memory or not.
|
||||
std::array<bool, NumATensor> as_mz_consecutive_;
|
||||
std::array<bool, NumATensor> as_kz_consecutive_;
|
||||
std::array<bool, NumBTensor> bs_nz_consecutive_;
|
||||
std::array<bool, NumBTensor> bs_kz_consecutive_;
|
||||
std::array<bool, NumDTensor> ds_nz_consecutive_;
|
||||
bool e_nz_consecutive_;
|
||||
|
||||
std::array<index_t, NumBTensor> b_nz_stride_;
|
||||
std::array<index_t, NumBTensor> b_kz_stride_;
|
||||
|
||||
std::array<index_t, NumDTensor> ds_nz_stride_;
|
||||
index_t e_nz_stride_;
|
||||
std::array<index_t, NumATensor> as_max_read_elems_;
|
||||
std::array<index_t, NumBTensor> bs_max_read_elems_;
|
||||
std::array<index_t, NumDTensor> ds_max_read_elems_;
|
||||
index_t e_max_write_elems_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
|
||||
// check vector load/store
|
||||
{
|
||||
bool all_valid = true;
|
||||
|
||||
bool valid_as_access = true;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
// vector memory access of A: could be on M or AK1 dimension
|
||||
if constexpr(ABlockTransferSrcVectorDim == 1)
|
||||
const bool valid_a_vector_size =
|
||||
arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_a_access_dim_m =
|
||||
ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
|
||||
const bool valid_a_access_dim_k =
|
||||
ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
|
||||
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
|
||||
if(!(valid_a_vector_size && valid_a_access_dim))
|
||||
{
|
||||
if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
|
||||
ABlockTransferSrcScalarPerVector ==
|
||||
0))
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
|
||||
ABlockTransferSrcScalarPerVector ==
|
||||
0))
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
valid_as_access = false;
|
||||
}
|
||||
});
|
||||
|
||||
// vector memory access of B: could be on N or BK1 dimension
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
if constexpr(BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
|
||||
BBlockTransferSrcScalarPerVector ==
|
||||
0))
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
|
||||
BBlockTransferSrcScalarPerVector ==
|
||||
0))
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// check vector load of Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
if(!(arg.ds_nz_stride_[i] == 1 &&
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
|
||||
CDEBlockTransferScalarPerVector_NPerBlock ==
|
||||
0))
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
// vector memory access of E: always on NPerBlock dimension
|
||||
if(!(arg.e_nz_stride_ == 1 &&
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
|
||||
CDEBlockTransferScalarPerVector_NPerBlock ==
|
||||
0))
|
||||
if(!valid_as_access)
|
||||
{
|
||||
all_valid = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!all_valid)
|
||||
bool valid_bs_access = true;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
const bool valid_b_vector_size =
|
||||
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_b_access_dim_n =
|
||||
BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
|
||||
const bool valid_b_access_dim_k =
|
||||
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
|
||||
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
|
||||
if(!(valid_b_vector_size && valid_b_access_dim))
|
||||
{
|
||||
valid_bs_access = false;
|
||||
}
|
||||
});
|
||||
if(!valid_bs_access)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool valid_ds_access = true;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
const bool valid_d_vector_size =
|
||||
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector read of Ds is always on N dimension.
|
||||
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
|
||||
if(!(valid_d_vector_size && valid_d_access_dim))
|
||||
{
|
||||
valid_ds_access = false;
|
||||
}
|
||||
});
|
||||
if(!valid_ds_access)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool valid_e_vector_size =
|
||||
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector write of E is always on N dimension.
|
||||
const bool valid_e_access_dim = arg.e_nz_consecutive_;
|
||||
if(!(valid_e_vector_size && valid_e_access_dim))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
|
||||
const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
|
||||
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
@@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds);
|
||||
const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds);
|
||||
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
|
||||
const auto a_grid_desc_ms_ks =
|
||||
make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides);
|
||||
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
|
||||
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
|
||||
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
|
||||
@@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
const void* p_b_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
void* p_e_grid,
|
||||
const std::vector<index_t>& a_ms_ns_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_strides,
|
||||
const std::vector<index_t>& b_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_ns_ks_strides,
|
||||
@@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
|
||||
@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_mz_stride_{},
|
||||
a_kz_stride_{},
|
||||
b_nz_stride_{},
|
||||
b_kz_stride_{},
|
||||
ds_nz_stride_{},
|
||||
e_nz_stride_{}
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// populate pointer, batch stride, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// for sanity check of vector memory access
|
||||
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
|
||||
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
|
||||
a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
|
||||
a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
|
||||
a_max_read_elems_ =
|
||||
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
|
||||
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
|
||||
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
|
||||
b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
|
||||
b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
|
||||
b_max_read_elems_ =
|
||||
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
|
||||
ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
|
||||
ds_max_read_elems_[i] =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
|
||||
}
|
||||
|
||||
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
|
||||
e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
|
||||
e_max_write_elems_ =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Strides for the last M/N/K dimensions of A/B/Ds/E
|
||||
// for sanity check of vector load/store
|
||||
index_t a_mz_stride_;
|
||||
index_t a_kz_stride_;
|
||||
index_t b_nz_stride_;
|
||||
index_t b_kz_stride_;
|
||||
std::array<index_t, NumDTensor> ds_nz_stride_;
|
||||
index_t e_mz_stride_;
|
||||
index_t e_nz_stride_;
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
|
||||
// in the memory or not.
|
||||
bool a_mz_consecutive_;
|
||||
bool a_kz_consecutive_;
|
||||
bool b_nz_consecutive_;
|
||||
bool b_kz_consecutive_;
|
||||
std::array<bool, NumDTensor> ds_nz_consecutive_;
|
||||
bool e_nz_consecutive_;
|
||||
|
||||
index_t a_max_read_elems_;
|
||||
index_t b_max_read_elems_;
|
||||
std::array<index_t, NumDTensor> ds_max_read_elems_;
|
||||
index_t e_max_write_elems_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
|
||||
"wrong!");
|
||||
|
||||
// vector memory access of A: could be on M or AK1 dimension
|
||||
if constexpr(ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
if(!(arg.a_mz_stride_ == 1 &&
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.a_kz_stride_ == 1 &&
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector memory access of B: could be on N or BK1 dimension
|
||||
if constexpr(BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
if(!(arg.b_nz_stride_ == 1 &&
|
||||
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.b_kz_stride_ == 1 &&
|
||||
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector memory access of Ds: always on NPerBlock dimension
|
||||
bool valid_d_access = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
if(!(arg.ds_nz_stride_[i] == 1 &&
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
|
||||
CDEBlockTransferScalarPerVector_NPerBlock ==
|
||||
0))
|
||||
{
|
||||
valid_d_access = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(valid_d_access == false)
|
||||
const bool valid_a_vector_size =
|
||||
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
|
||||
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
|
||||
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
|
||||
if(!(valid_a_vector_size && valid_a_access_dim))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector memory access of E: always on NPerBlock dimension
|
||||
if(!(arg.e_nz_stride_ == 1 &&
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
|
||||
CDEBlockTransferScalarPerVector_NPerBlock ==
|
||||
0))
|
||||
const bool valid_b_vector_size =
|
||||
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
|
||||
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
|
||||
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
|
||||
if(!(valid_b_vector_size && valid_b_access_dim))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool valid_ds_access = true;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
const bool valid_d_vector_size =
|
||||
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector read of Ds is always on N dimension.
|
||||
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
|
||||
if(!(valid_d_vector_size && valid_d_access_dim))
|
||||
{
|
||||
valid_ds_access = false;
|
||||
}
|
||||
});
|
||||
if(!valid_ds_access)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool valid_e_vector_size =
|
||||
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector write of E is always on N dimension.
|
||||
const bool valid_e_access_dim = arg.e_nz_consecutive_;
|
||||
if(!(valid_e_vector_size && valid_e_access_dim))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::vector<index_t>& a_ms_ns_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_strides,
|
||||
const std::vector<index_t>& b_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_ns_ks_strides,
|
||||
@@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_ms_ns_lengths,
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
@@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::vector<index_t>& a_ms_ns_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_strides,
|
||||
const std::vector<index_t>& b_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_ns_ks_strides,
|
||||
@@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_ms_ns_lengths,
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/**
|
||||
* Calculates the maximum number of subsequent elements of the fast changing dimension
|
||||
* that are consecutive in memory.
|
||||
*
|
||||
* Example:
|
||||
* NumDimM = 2, NumDimK = 3
|
||||
* A shape = [ 2, 3, 4, 5, 6]
|
||||
* A strides = [360, 120, 30, 6, 1]
|
||||
* | M | | K |
|
||||
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
|
||||
* in memory.
|
||||
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
|
||||
* consecutive in memory.
|
||||
*
|
||||
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
|
||||
*/
|
||||
template <index_t NumDim1, index_t NumDim2>
|
||||
auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
|
||||
{
|
||||
if(lengths.size() != NumDim1 + NumDim2)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
if(strides.size() != NumDim1 + NumDim2)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
// Determine the beginning and end idx of the group representing the FCD.
|
||||
index_t begin_idx, end_idx;
|
||||
if(strides[NumDim1 - 1] == 1)
|
||||
{
|
||||
begin_idx = 0;
|
||||
end_idx = NumDim1 - 1;
|
||||
}
|
||||
else if(strides[NumDim1 + NumDim2 - 1] == 1)
|
||||
{
|
||||
begin_idx = NumDim1;
|
||||
end_idx = NumDim1 + NumDim2 - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// The dimension consecutive in memory is not the last dimension of any group, so only
|
||||
// one element can be read/written at once.
|
||||
return 1;
|
||||
}
|
||||
|
||||
index_t consecutive_stride = 1;
|
||||
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
|
||||
{
|
||||
if(strides[dim_idx] == consecutive_stride)
|
||||
{
|
||||
consecutive_stride *= lengths[dim_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
const index_t max_subsequent_elems = consecutive_stride;
|
||||
return max_subsequent_elems;
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user