mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-19 22:38:52 +00:00
Improvements to readability of examples per PR review
This commit is contained in:
@@ -61,10 +61,6 @@ def segmented_reduce(state: nvbench.State):
|
||||
dev_id = state.get_device()
|
||||
cp_stream = as_cp_ExternalStream(state.get_stream(), dev_id)
|
||||
|
||||
with cp_stream:
|
||||
rng = cp.random.default_rng()
|
||||
mat = rng.integers(low=-31, high=32, dtype=np.int32, size=(n_rows, n_cols))
|
||||
|
||||
def add_op(a, b):
|
||||
return a + b
|
||||
|
||||
@@ -84,6 +80,8 @@ def segmented_reduce(state: nvbench.State):
|
||||
|
||||
h_init = np.zeros(tuple(), dtype=np.int32)
|
||||
with cp_stream:
|
||||
rng = cp.random.default_rng()
|
||||
mat = rng.integers(low=-31, high=32, dtype=np.int32, size=(n_rows, n_cols))
|
||||
d_input = mat
|
||||
d_output = cp.empty(n_rows, dtype=d_input.dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user