CUTLASS 3.2 (#1024)

* CUTLASS 3.2
This commit is contained in:
ANIKET SHIVAM
2023-08-07 14:50:32 -10:00
committed by GitHub
parent a0d787b746
commit 4575443d44
392 changed files with 47559 additions and 7940 deletions

View File

@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "1ef96b3f",
"metadata": {},
@@ -12,6 +13,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "962324fd",
"metadata": {},
@@ -31,8 +33,8 @@
"\n",
"import cutlass\n",
"\n",
"# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n",
"# omit this information.\n",
"# This controls whether the C++ GEMM declaration will be printed at each step. \n",
"# Set to `False` to omit this information.\n",
"print_module = True\n",
"\n",
"m = 128\n",
@@ -60,6 +62,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f2c7bf48",
"metadata": {},
@@ -87,6 +90,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4a5856de",
"metadata": {},
@@ -95,6 +99,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "945478ef",
"metadata": {},
@@ -114,6 +119,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ee5cbbbe",
"metadata": {},
@@ -122,6 +128,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b6c86493",
"metadata": {},
@@ -143,6 +150,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6d27c575",
"metadata": {},
@@ -167,6 +175,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "639dcb59",
"metadata": {},
@@ -185,6 +194,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0cce1eae",
"metadata": {},
@@ -219,6 +229,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "52a4e318",
"metadata": {},
@@ -245,6 +256,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "dc3ad875",
"metadata": {},
@@ -267,6 +279,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c5a8b534",
"metadata": {},
@@ -281,13 +294,23 @@
"metadata": {},
"outputs": [],
"source": [
"# Stream K is only supported pre-SM90 (at least when this example was written)\n",
"if plan.cc != 90:\n",
"# Stream K is exposed through the threadblock swizzle method for pre-SM90 kernels,\n",
"# and via the tile_scheduler attribute of the TileDescription for post-SM90 kernels\n",
"if plan.cc < 90:\n",
" plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n",
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)\n",
"else:\n",
" # Stream-K is currently only supported for warp-specialized cooperative kernels\n",
" td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative\n",
" td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative\n",
" td.tile_scheduler = cutlass.TileSchedulerType.StreamK\n",
"\n",
" plan.compile(td)\n",
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5a8ba2ba",
"metadata": {},
@@ -327,7 +350,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.6.9"
},
"vscode": {
"interpreter": {