mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
* Turning compare warnings on * Cleaning part I * Cleaning part II * Explicit static_cast to ck::type_convert * Resolving large tensor size issue. * format * revert change to tensor descriptor; promote lementSpaceSize to 64bit * use integer value for GEMM test * Review remarks * Review remarks + issues with (un)signed arithmetic * Format fix * Format * Clang-format. * fix 2gb limit issue Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Adam Osewski <aosewski@amd.com>
This commit is contained in:
@@ -635,11 +635,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
d_grid_desc_mblock_mperblock_{},
|
||||
compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
|
||||
b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
|
||||
c_grid_desc_m_n_.GetElementSpaceSize(),
|
||||
d_grid_desc_m_.GetElementSpaceSize(),
|
||||
d_grid_desc_m_.GetElementSpaceSize()},
|
||||
compute_base_ptr_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
|
||||
@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
|
||||
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
|
||||
b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
|
||||
c_grid_desc_m_n_.GetElementSpaceSize()},
|
||||
compute_ptr_offset_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
|
||||
@@ -697,7 +697,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
|
||||
@@ -1412,7 +1412,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
|
||||
@@ -861,17 +861,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Input tensors can't be bigger than 2GB each.
|
||||
constexpr std::size_t GB2 = 2 * 1e9;
|
||||
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
|
||||
|
||||
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2)
|
||||
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
|
||||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
|
||||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -372,17 +372,18 @@ struct DeviceGroupedGemmXdl
|
||||
{
|
||||
grid_size_ = 0;
|
||||
|
||||
group_count_ = static_cast<int>(gemm_shapes.size());
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
|
||||
|
||||
if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
|
||||
group_count_ == p_c.size()))
|
||||
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_b.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_c.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
|
||||
}
|
||||
|
||||
gemm_desc_kernel_arg_.reserve(group_count_);
|
||||
|
||||
for(index_t i = 0; i < gemm_shapes.size(); i++)
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
{
|
||||
const index_t M = gemm_shapes[i].M;
|
||||
const index_t N = gemm_shapes[i].N;
|
||||
@@ -563,7 +564,7 @@ struct DeviceGroupedGemmXdl
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_)
|
||||
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
|
||||
return false;
|
||||
else
|
||||
return true;
|
||||
|
||||
Reference in New Issue
Block a user