mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add Grouped Conv Fwd Large Tensor kernel (#1432)
* Support 64-bit indexing
* Add new grouped conv fwd kernel for large tensors
* Add large tensor instances
* Fixes for transform conv to gemm
* Fixes
* Fixes
* Remove unneeded instances
* Example fixes
* Remove unneeded Ds arrays
* Fix tests
* Add 2GB check in gridwise dl
* Fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
|
||||
const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
|
||||
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
|
||||
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
|
||||
|
||||
Reference in New Issue
Block a user