mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Add grouped conv bwd weight wmma (#985)
* Add grouped conv bwd weight wmma * Update README, changelog, profiler * Minor fixes * Fix grouped conv bwd wei dl kernel * Minor fixes * Minor stylistic fixes
This commit is contained in:
@@ -36,7 +36,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle(
|
||||
kernel_grouped_conv_multiple_d_wmma_cshuffle(
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
|
||||
// CheckValidity for kernels without multi D
|
||||
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
const EGridDesc_M_N& e_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
bool valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
|
||||
N == ds_grid_desc_m_n[i].GetLength(I1));
|
||||
});
|
||||
|
||||
if(!valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
|
||||
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
|
||||
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
|
||||
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
// CheckValidity overload that also takes the Ds descriptors.
//
// Checks that every entry of ds_grid_desc_m_n reports length M at I0 and
// length N at I1, where M and N are read from a_grid_desc_k0_m_k1 and
// b_grid_desc_k0_n_k1. If any entry disagrees, returns false. Otherwise the
// result is whatever the four-argument CheckValidity overload (A, B, E
// descriptors and the block-to-C-tile map) returns.
//
// NOTE(review): the four-argument overload is not fully visible in this view;
// confirm which remaining conditions it validates.
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
const EGridDesc_M_N& e_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
// M is taken from dimension I1 of the A descriptor, N from dimension I1 of the B descriptor.
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
bool valid = true;
|
||||
// Accumulate the shape check over all NumDTensor D descriptors; a single mismatch
// leaves `valid` false for the rest of the iterations.
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
|
||||
N == ds_grid_desc_m_n[i].GetLength(I1));
|
||||
});
|
||||
|
||||
if(!valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// All D shapes agree with (M, N); the remaining checks are delegated to the
// overload that does not take ds_grid_desc_m_n.
return CheckValidity(
|
||||
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / (K0PerBlock * K1);
|
||||
|
||||
Reference in New Issue
Block a user