mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
add GetWorkSpaceSize to base arg (#253)
* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight
* remove redundant compute
* use datatype and split k to check whether a workspace is used
* remove unused computation for work space size
[ROCm/composable_kernel commit: 0d08cf1893]
This commit is contained in:
@@ -257,11 +257,11 @@ int main(int argc, char* argv[])
|
||||
case 0: break;
|
||||
case 1:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
@@ -296,15 +296,53 @@ int main(int argc, char* argv[])
|
||||
OutElementOp{},
|
||||
split_k);
|
||||
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
// alloc work space
|
||||
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
|
||||
float ave_time = 0.f;
|
||||
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
|
||||
wei_work_space_device_buf.SetZero();
|
||||
argument = conv->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<AccDataType*>(wei_work_space_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{},
|
||||
split_k);
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
}
|
||||
|
||||
std::size_t flop = ck::utils::conv::get_flops(
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
|
||||
Reference in New Issue
Block a user