diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 2e5e3b0c8..dff737ffe 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -244,7 +244,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index e1d0092cd..b398688f2 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -482,7 +482,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -673,7 +673,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -699,7 +699,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index 12bc300ff..be7be25de 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -473,7 +473,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -700,7 +700,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -726,7 +726,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index 8f09b4a77..e10932704 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -451,7 +451,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -666,7 +666,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -692,7 +692,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 70e0b46a3..4b0fcaa98 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -465,7 +465,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -674,7 +674,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -700,7 +700,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h index 04e2dd667..b37f585a9 100644 --- a/include/cutlass/gemm/device/gemm_sparse.h +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -498,7 +498,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/reduction/device/reduce_split_k.h b/include/cutlass/reduction/device/reduce_split_k.h index 4c044a4ca..f8558643a 100644 --- a/include/cutlass/reduction/device/reduce_split_k.h +++ b/include/cutlass/reduction/device/reduce_split_k.h @@ -197,7 +197,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream);