From 92393b2676220149180e99601b9ccc3ce7291e6f Mon Sep 17 00:00:00 2001 From: Peter Han Date: Tue, 23 Mar 2021 21:11:42 +0800 Subject: [PATCH 1/2] Bugfix: memsetAsync uses wrong default stream Signed-off-by: Peter Han --- include/cutlass/conv/device/implicit_gemm_convolution.h | 2 +- include/cutlass/gemm/device/gemm.h | 6 +++--- include/cutlass/gemm/device/gemm_array.h | 6 +++--- include/cutlass/gemm/device/gemm_batched.h | 6 +++--- include/cutlass/gemm/device/gemm_complex.h | 6 +++--- include/cutlass/gemm/device/gemm_sparse.h | 2 +- include/cutlass/gemm/device/gemm_universal_base.h | 2 +- include/cutlass/reduction/device/reduce_split_k.h | 2 +- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 2e5e3b0c8..dff737ffe 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -244,7 +244,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index e1d0092cd..b398688f2 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -482,7 +482,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -673,7 +673,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -699,7 +699,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index 12bc300ff..be7be25de 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -473,7 +473,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -700,7 +700,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -726,7 +726,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index 8f09b4a77..e10932704 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -451,7 +451,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -666,7 +666,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -692,7 +692,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 70e0b46a3..4b0fcaa98 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -465,7 +465,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -674,7 +674,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -700,7 +700,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h index 04e2dd667..b37f585a9 100644 --- a/include/cutlass/gemm/device/gemm_sparse.h +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -498,7 +498,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 74c519a44..f15b3589a 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -347,7 +347,7 @@ public: return Status::kErrorWorkspaceNull; } - params_.update(args, workspace); + params_.update(args, workspace, stream); return Status::kSuccess; } diff --git a/include/cutlass/reduction/device/reduce_split_k.h b/include/cutlass/reduction/device/reduce_split_k.h index 4c044a4ca..f8558643a 100644 --- a/include/cutlass/reduction/device/reduce_split_k.h +++ b/include/cutlass/reduction/device/reduce_split_k.h @@ -197,7 +197,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); From 6a6b4028bd3a3f95721587dced4a97093eec4ab2 Mon Sep 17 00:00:00 2001 From: Peter Han Date: Tue, 23 Mar 2021 23:20:40 +0800 Subject: [PATCH 2/2] Revert wrong fix of params.update in GemmUniversalBase Signed-off-by: Peter Han --- include/cutlass/gemm/device/gemm_universal_base.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index f15b3589a..74c519a44 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -347,7 +347,7 @@ public: return Status::kErrorWorkspaceNull; } - params_.update(args, workspace, stream); + params_.update(args, workspace); return Status::kSuccess; }