Merge pull request #2919 from pbelevich/patch-1

Refactor binary_op functions to remove unused result parameter
This commit is contained in:
drazi
2026-02-11 11:48:58 +08:00
committed by GitHub

View File

@@ -167,7 +167,7 @@
"outputs": [],
"source": [
"@cute.jit\n",
"def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
"def binary_op_1(a: cute.Tensor, b: cute.Tensor):\n",
" a_vec = a.load()\n",
" b_vec = b.load()\n",
"\n",
@@ -184,7 +184,7 @@
" cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n",
"\n",
" floor_div_res = a_vec // b_vec\n",
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
" cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]\n",
"\n",
" mod_res = a_vec % b_vec\n",
" cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n",
@@ -194,8 +194,7 @@
"a.fill(1.0)\n",
"b = np.empty((3,), dtype=np.float32)\n",
"b.fill(2.0)\n",
"res = np.empty((3,), dtype=np.float32)\n",
"binary_op_1(from_dlpack(res), from_dlpack(a), from_dlpack(b))"
"binary_op_1(from_dlpack(a), from_dlpack(b))"
]
},
{
@@ -205,7 +204,7 @@
"outputs": [],
"source": [
"@cute.jit\n",
"def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
"def binary_op_2(a: cute.Tensor, c: cutlass.Constexpr):\n",
" a_vec = a.load()\n",
"\n",
" add_res = a_vec + c\n",
@@ -230,8 +229,7 @@
"a = np.empty((3,), dtype=np.float32)\n",
"a.fill(1.0)\n",
"c = 2.0\n",
"res = np.empty((3,), dtype=np.float32)\n",
"binary_op_2(from_dlpack(res), from_dlpack(a), c)"
"binary_op_2(from_dlpack(a), c)"
]
},
{