mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Integrate universal gemm with conv bwd data and add SplitK (#1315)
* Integrate universal gemm with conv bwd data
* Fix multi d kernel
* Add splitK support
* instances refactor
* instances refactor
* refactor
* fixeS
* fixes
* 16x16 instnaces
* Fixes
* Fix
* Fix
* Fix
* Fix
* Fix
* Fixes
* fix
* fix
[ROCm/composable_kernel commit: 4094ad158a]
This commit is contained in:
@@ -68,8 +68,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial)
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -77,6 +77,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
|
||||
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
|
||||
|
||||
ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
@@ -110,7 +112,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
InDataType>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
do_verification, init_method, do_log, time_kernel, params, split_k);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user