mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
Merge pull request #2919 from pbelevich/patch-1
Refactor binary_op functions to remove unused result parameter
This commit is contained in:
@@ -167,7 +167,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
"def binary_op_1(a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
@@ -184,7 +184,7 @@
|
||||
" cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
"\n",
|
||||
" floor_div_res = a_vec // b_vec\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
" cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
"\n",
|
||||
" mod_res = a_vec % b_vec\n",
|
||||
" cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
@@ -194,8 +194,7 @@
|
||||
"a.fill(1.0)\n",
|
||||
"b = np.empty((3,), dtype=np.float32)\n",
|
||||
"b.fill(2.0)\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"binary_op_1(from_dlpack(res), from_dlpack(a), from_dlpack(b))"
|
||||
"binary_op_1(from_dlpack(a), from_dlpack(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -205,7 +204,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
|
||||
"def binary_op_2(a: cute.Tensor, c: cutlass.Constexpr):\n",
|
||||
" a_vec = a.load()\n",
|
||||
"\n",
|
||||
" add_res = a_vec + c\n",
|
||||
@@ -230,8 +229,7 @@
|
||||
"a = np.empty((3,), dtype=np.float32)\n",
|
||||
"a.fill(1.0)\n",
|
||||
"c = 2.0\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"binary_op_2(from_dlpack(res), from_dlpack(a), c)"
|
||||
"binary_op_2(from_dlpack(a), c)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user