mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Fixed handling of split-K autodeduce argument for grouped convolution (#3024)
* Fix handling of split-K autodeduce argument. * Fix clang formatting. * Test fix. * Fix clang formatting.
This commit is contained in:
@@ -689,6 +689,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split-K autodeduction is not supported
|
||||
if(arg.k_batch_ < 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
|
||||
@@ -1523,6 +1523,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Split-K autodeduction is not supported.
|
||||
if(arg.k_batch_ < 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
|
||||
@@ -688,6 +688,12 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Split-K autodeduction is not supported
|
||||
if(arg.k_batch_ < 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
|
||||
{
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user