Replace hipMemcpy with hipMemcpyWithStream (#734)

This commit is contained in:
who who who
2023-06-02 05:23:41 +08:00
committed by GitHub
parent 9eae73df9b
commit e2ebc8e795
5 changed files with 26 additions and 21 deletions

View File

@@ -652,11 +652,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
}
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
arg.contraction_multi_d_kernel_args_.data(),
arg.contraction_multi_d_kernel_args_.size() *
sizeof(ContractionMultiDKernelArg),
hipMemcpyHostToDevice));
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.contraction_multi_d_kernel_args_.data(),
arg.contraction_multi_d_kernel_args_.size() *
sizeof(ContractionMultiDKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;

View File

@@ -597,10 +597,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice));
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {

View File

@@ -549,10 +549,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
hipGetErrorString(
hipMemcpy(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice));
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;

View File

@@ -406,10 +406,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
}
hip_check_error(hipMemcpy(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice));
hip_check_error(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;