mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-14 02:02:25 +00:00
923 lines
72 KiB
HTML
923 lines
72 KiB
HTML
<!doctype html>
|
|
<html class="no-js" lang="en">
|
|
<head>
  <!-- Document metadata. The character encoding is declared first. -->
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width,initial-scale=1">
  <meta name="color-scheme" content="light dark">
  <link rel="index" title="Index" href="../../../genindex.html">
  <link rel="search" title="Search" href="../../../search.html">
  <!-- NOTE(review): this canonical href is a relative path; canonical URLs are normally absolute. Confirm the published base URL before changing it. -->
  <link rel="canonical" href="docs/_modules/cutlass/emit/pytorch.html">

  <!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
  <title>cutlass.emit.pytorch - CUTLASS Python</title>

  <link rel="stylesheet" href="../../../_static/pygments.css">
  <link rel="stylesheet" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53">
  <link rel="stylesheet" href="../../../_static/copybutton.css">
  <link rel="stylesheet" href="../../../_static/tabs.css">
  <link rel="stylesheet" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e">

  <style>
    /* Default (light) values for the code-block and brand colour variables. */
    body {
      --color-code-background: #eeffcc;
      --color-code-foreground: black;
      --color-brand-primary: #76B900;
      --color-brand-content: #76B900;
    }

    @media not print {
      /* Dark values when the theme is explicitly set to "dark". */
      body[data-theme="dark"] {
        --color-code-background: #272822;
        --color-code-foreground: #f8f8f2;
        --color-brand-primary: #76B900;
        --color-brand-content: #76B900;
      }

      /* Same dark values when the OS prefers dark and the theme is not forced to "light". */
      @media (prefers-color-scheme: dark) {
        body:not([data-theme="light"]) {
          --color-code-background: #272822;
          --color-code-foreground: #f8f8f2;
          --color-brand-primary: #76B900;
          --color-brand-content: #76B900;
        }
      }
    }
  </style>
</head>
|
|
<body>
|
|
|
|
<script>
  // Runs inline at the start of <body>: copies the saved theme choice onto
  // body[data-theme], falling back to "auto" when nothing is stored.
  (function () {
    var theme = "auto";
    try {
      theme = localStorage.getItem("theme") || "auto";
    } catch (err) {
      // Accessing localStorage can throw (e.g. storage blocked by browser
      // policy); keep the "auto" default instead of aborting the script.
    }
    document.body.dataset.theme = theme;
  })();
</script>
|
|
|
|
|
|
<!-- Hidden SVG sprite: each <symbol> is referenced elsewhere on the page via <use href="#id">. -->
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
  <symbol id="svg-toc" viewBox="0 0 24 24">
    <title>Contents</title>
    <svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
      <path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
    </svg>
  </symbol>

  <symbol id="svg-menu" viewBox="0 0 24 24">
    <title>Menu</title>
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
      <line x1="3" y1="12" x2="21" y2="12"></line>
      <line x1="3" y1="6" x2="21" y2="6"></line>
      <line x1="3" y1="18" x2="21" y2="18"></line>
    </svg>
  </symbol>

  <symbol id="svg-arrow-right" viewBox="0 0 24 24">
    <title>Expand</title>
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
      <polyline points="9 18 15 12 9 6"></polyline>
    </svg>
  </symbol>

  <symbol id="svg-sun" viewBox="0 0 24 24">
    <title>Light mode</title>
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
      <circle cx="12" cy="12" r="5"></circle>
      <line x1="12" y1="1" x2="12" y2="3"></line>
      <line x1="12" y1="21" x2="12" y2="23"></line>
      <line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
      <line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
      <line x1="1" y1="12" x2="3" y2="12"></line>
      <line x1="21" y1="12" x2="23" y2="12"></line>
      <line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
      <line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
    </svg>
  </symbol>

  <symbol id="svg-moon" viewBox="0 0 24 24">
    <title>Dark mode</title>
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
      <path stroke="none" d="M0 0h24v24H0z" fill="none" />
      <path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
    </svg>
  </symbol>

  <symbol id="svg-sun-half" viewBox="0 0 24 24">
    <title>Auto light/dark mode</title>
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
      <path stroke="none" d="M0 0h24v24H0z" fill="none"/>
      <circle cx="12" cy="12" r="9" />
      <path d="M13 12h5" />
      <path d="M13 15h4" />
      <path d="M13 18h1" />
      <path d="M13 9h4" />
      <path d="M13 6h1" />
    </svg>
  </symbol>
</svg>
|
|
|
|
<!-- Checkboxes targeted by the sidebar labels below and in the page header (for="__navigation" / for="__toc"). -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
<!-- <label> only permits phrasing content, so the hidden text is a <span> rather than a <div>. -->
<label class="overlay sidebar-overlay" for="__navigation">
  <span class="visually-hidden">Hide navigation sidebar</span>
</label>
<label class="overlay toc-overlay" for="__toc">
  <span class="visually-hidden">Hide table of contents sidebar</span>
</label>
|
|
|
|
|
|
|
|
<div class="page">
|
|
<!-- Mobile header: navigation toggle, brand link, theme toggle and table-of-contents toggle. -->
<header class="mobile-header">
  <div class="header-left">
    <label class="nav-overlay-icon" for="__navigation">
      <span class="visually-hidden">Toggle site navigation sidebar</span>
      <!-- The icon is decorative; the hidden text above is the accessible name. -->
      <i class="icon"><svg aria-hidden="true"><use href="#svg-menu"></use></svg></i>
    </label>
  </div>
  <div class="header-center">
    <a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
  </div>
  <div class="header-right">
    <div class="theme-toggle-container theme-toggle-header">
      <!-- type="button" states explicitly that this is an in-page action, not a form submit. -->
      <button class="theme-toggle" type="button">
        <span class="visually-hidden">Toggle Light / Dark / Auto color theme</span>
        <svg class="theme-icon-when-auto" aria-hidden="true"><use href="#svg-sun-half"></use></svg>
        <svg class="theme-icon-when-dark" aria-hidden="true"><use href="#svg-moon"></use></svg>
        <svg class="theme-icon-when-light" aria-hidden="true"><use href="#svg-sun"></use></svg>
      </button>
    </div>
    <label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
      <span class="visually-hidden">Toggle table of contents sidebar</span>
      <i class="icon"><svg aria-hidden="true"><use href="#svg-toc"></use></svg></i>
    </label>
  </div>
</header>
|
|
<!-- Navigation sidebar: brand link, search form and the documentation tree. -->
<aside class="sidebar-drawer">
  <div class="sidebar-container">
    <div class="sidebar-sticky">
      <a class="sidebar-brand" href="../../../index.html">
        <div class="sidebar-logo-container">
          <!-- Logos are decorative here: the link is already named by the brand text below. -->
          <img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="">
          <img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="">
        </div>
        <span class="sidebar-brand-text">CUTLASS Python</span>
      </a>
      <form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
        <input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
        <input type="hidden" name="check_keywords" value="yes">
        <input type="hidden" name="area" value="default">
      </form>
      <div id="searchbox"></div>
      <div class="sidebar-scroll">
        <div class="sidebar-tree">
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
            <li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
          <ul>
            <li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"><label for="toctree-checkbox-1"><span class="visually-hidden">Toggle child pages in navigation</span><i class="icon"><svg aria-hidden="true"><use href="#svg-arrow-right"></use></svg></i></label>
              <ul>
                <li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"><label for="toctree-checkbox-2"><span class="visually-hidden">Toggle child pages in navigation</span><i class="icon"><svg aria-hidden="true"><use href="#svg-arrow-right"></use></svg></i></label>
                  <ul>
                    <li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
                    <li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
                    <li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
                  </ul>
                </li>
              </ul>
            </li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
          <ul>
            <li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"><label for="toctree-checkbox-3"><span class="visually-hidden">Toggle child pages in navigation</span><i class="icon"><svg aria-hidden="true"><use href="#svg-arrow-right"></use></svg></i></label>
              <ul>
                <li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
                <li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
                <li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
              </ul>
            </li>
          </ul>
          <p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
          <ul>
            <li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">GitHub</a></li>
          </ul>
        </div>
      </div>
    </div>
  </div>
</aside>
|
|
<div class="main">
|
|
<div class="content">
|
|
<div class="article-container">
|
|
<!-- In-page link back to the top of the document; the arrow is decorative, the text names the link. -->
<a href="#" class="back-to-top muted-link">
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true">
    <path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
  </svg>
  <span>Back to top</span>
</a>
|
|
<!-- In-content controls: theme toggle and table-of-contents toggle. -->
<div class="content-icon-container">
  <div class="theme-toggle-container theme-toggle-content">
    <!-- type="button" states explicitly that this is an in-page action, not a form submit. -->
    <button class="theme-toggle" type="button">
      <span class="visually-hidden">Toggle Light / Dark / Auto color theme</span>
      <!-- Icons are decorative; the hidden text above is the accessible name. -->
      <svg class="theme-icon-when-auto" aria-hidden="true"><use href="#svg-sun-half"></use></svg>
      <svg class="theme-icon-when-dark" aria-hidden="true"><use href="#svg-moon"></use></svg>
      <svg class="theme-icon-when-light" aria-hidden="true"><use href="#svg-sun"></use></svg>
    </button>
  </div>
  <label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
    <span class="visually-hidden">Toggle table of contents sidebar</span>
    <i class="icon"><svg aria-hidden="true"><use href="#svg-toc"></use></svg></i>
  </label>
</div>
|
|
<article role="main">
|
|
<h1>Source code for cutlass.emit.pytorch</h1><div class="highlight"><pre>
|
|
<span></span><span class="c1">#################################################################################################</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
|
|
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
|
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
|
<span class="c1"># list of conditions and the following disclaimer.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
|
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
|
<span class="c1"># and/or other materials provided with the distribution.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
|
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
|
<span class="c1"># this software without specific prior written permission.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
|
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
|
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
|
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
|
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
|
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
|
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
|
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
|
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
|
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1">#################################################################################################</span>
|
|
|
|
<span class="sd">"""</span>
|
|
<span class="sd">Utilities for generating source for building a PyTorch CUDA extension that uses a CUTLASS kernel.</span>
|
|
<span class="sd">If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.</span>
|
|
|
|
<span class="sd">Example usage with JIT compilation:</span>
|
|
|
|
<span class="sd">.. highlight:: python</span>
|
|
<span class="sd">.. code-block:: python</span>
|
|
|
|
<span class="sd"> plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
|
<span class="sd"> op = plan.construct()</span>
|
|
<span class="sd"> mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)</span>
|
|
|
|
<span class="sd"> # Generate inputs for the GEMM</span>
|
|
<span class="sd"> A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]</span>
|
|
|
|
<span class="sd"> # Run the module</span>
|
|
<span class="sd"> D = mod.run(A, B, C)</span>
|
|
|
|
|
|
<span class="sd">Example usage without JIT compilation:</span>
|
|
|
|
<span class="sd">.. highlight:: python</span>
|
|
<span class="sd">.. code-block:: python</span>
|
|
|
|
<span class="sd"> plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)</span>
|
|
<span class="sd"> op = plan.construct()</span>
|
|
<span class="sd"> cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')</span>
|
|
|
|
<span class="sd">After this call, the directory ``output`` contains ``setup.py``,</span>
|
|
<span class="sd">``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from</span>
|
|
<span class="sd">within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.</span>
|
|
|
|
<span class="sd">The module can later be used in Python via:</span>
|
|
|
|
<span class="sd">.. highlight:: python</span>
|
|
<span class="sd">.. code-block:: python</span>
|
|
|
|
<span class="sd"> import torch</span>
|
|
<span class="sd"> import cutlass_gemm</span>
|
|
|
|
<span class="sd"> # Generate inputs for the GEMM</span>
|
|
<span class="sd"> A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]</span>
|
|
|
|
<span class="sd"> # Run the module</span>
|
|
<span class="sd"> D = cutlass_gemm.run(A, B, C)</span>
|
|
<span class="sd">"""</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">logging</span>
|
|
<span class="kn">import</span> <span class="nn">os</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
|
|
|
<span class="kn">from</span> <span class="nn">cutlass</span> <span class="kn">import</span> <span class="n">CUTLASS_PATH</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">swizzle</span>
|
|
<span class="kn">from</span> <span class="nn">cutlass.backend.gemm_operation</span> <span class="kn">import</span> <span class="n">GemmOperationGrouped</span><span class="p">,</span> <span class="n">GemmOperationUniversal</span>
|
|
<span class="kn">from</span> <span class="nn">cutlass.backend.library</span> <span class="kn">import</span> <span class="n">ApiVersion</span>
|
|
<span class="kn">from</span> <span class="nn">cutlass.backend.utils.software</span> <span class="kn">import</span> <span class="n">CheckPackages</span><span class="p">,</span> <span class="n">SubstituteTemplate</span>
|
|
<span class="kn">from</span> <span class="nn">cutlass.emit</span> <span class="kn">import</span> <span class="n">common</span>
|
|
|
|
<span class="n">torch_available</span> <span class="o">=</span> <span class="n">CheckPackages</span><span class="p">()</span><span class="o">.</span><span class="n">check_torch</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="n">torch_available</span><span class="p">:</span>
|
|
<span class="kn">import</span> <span class="nn">torch</span>
|
|
|
|
|
|
<span class="n">_PYTORCH_CUDA_TEMPLATE</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
|
<span class="s2">#include &lt;torch/extension.h&gt;</span>
|
|
<span class="s2">#include &lt;ATen/ATen.h&gt;</span>
|
|
|
|
<span class="s2">#include "cutlass/cutlass.h"</span>
|
|
<span class="s2">#include "cutlass/util/device_memory.h"</span>
|
|
|
|
<span class="s2">$</span><span class="si">{includes}</span>
|
|
<span class="s2">$</span><span class="si">{declaration}</span>
|
|
<span class="s2">$</span><span class="si">{impl}</span>
|
|
<span class="s2">"""</span>
|
|
|
|
<span class="n">_PYTORCH_GEMM_CPP_TEMPLATE</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
|
<span class="s2">#include &lt;torch/extension.h&gt;</span>
|
|
<span class="s2">#include &lt;ATen/ATen.h&gt;</span>
|
|
<span class="s2">#include &lt;pybind11/stl.h&gt;</span>
|
|
|
|
<span class="s2">// CUDA forward declarations</span>
|
|
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">_kernel(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C=at::nullopt, float alpha=1.f, float beta=0.f);</span>
|
|
|
|
<span class="s2">// C++ interface</span>
|
|
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C=at::nullopt, float alpha=1.f, float beta=0.f) {</span>
|
|
<span class="s2"> return $</span><span class="si">{name}</span><span class="s2">_kernel(A, B, C, alpha, beta);</span>
|
|
<span class="s2">}</span>
|
|
|
|
<span class="s2">PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {</span>
|
|
<span class="s2"> m.def("run", py::overload_cast&lt;const at::Tensor&amp;, const at::Tensor&amp;, at::optional&lt;const at::Tensor&gt;, float, float&gt;(&amp;$</span><span class="si">{name}</span><span class="s2">), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);</span>
|
|
<span class="s2">}</span>
|
|
<span class="s2">"""</span>
|
|
|
|
<span class="n">_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
|
<span class="s2">#include &lt;torch/extension.h&gt;</span>
|
|
<span class="s2">#include &lt;ATen/ATen.h&gt;</span>
|
|
<span class="s2">#include &lt;pybind11/stl.h&gt;</span>
|
|
|
|
<span class="s2">// CUDA forward declarations</span>
|
|
<span class="s2">std::vector&lt;at::Tensor&gt; $</span><span class="si">{name}</span><span class="s2">_kernel(const std::vector&lt;at::Tensor&gt;&amp; A, const std::vector&lt;at::Tensor&gt;&amp; B, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt; C=at::nullopt, float alpha=1.f, float beta=0.f);</span>
|
|
|
|
<span class="s2">// C++ interface</span>
|
|
<span class="s2">std::vector&lt;at::Tensor&gt; $</span><span class="si">{name}</span><span class="s2">(const std::vector&lt;at::Tensor&gt;&amp; A, const std::vector&lt;at::Tensor&gt;&amp; B, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt; C=at::nullopt, float alpha=1.f, float beta=0.f) {</span>
|
|
<span class="s2"> return $</span><span class="si">{name}</span><span class="s2">_kernel(A, B, C, alpha, beta);</span>
|
|
<span class="s2">}</span>
|
|
|
|
<span class="s2">PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {</span>
|
|
<span class="s2"> m.def("run", py::overload_cast&lt;const std::vector&lt;at::Tensor&gt;&amp;, const std::vector&lt;at::Tensor&gt;&amp;, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt;, float, float&gt;(&amp;$</span><span class="si">{name}</span><span class="s2">),</span>
|
|
<span class="s2"> py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);</span>
|
|
<span class="s2">}</span>
|
|
<span class="s2">"""</span>
|
|
|
|
<span class="n">_PYTORCH_GEMM_INCLUDES</span> <span class="o">=</span> <span class="p">{</span>
|
|
<span class="n">ApiVersion</span><span class="o">.</span><span class="n">v2x</span><span class="p">:</span> <span class="s2">"""</span>
|
|
<span class="s2">#include "cutlass/gemm/device/gemm_universal.h"</span>
|
|
<span class="s2">"""</span><span class="p">,</span>
|
|
<span class="n">ApiVersion</span><span class="o">.</span><span class="n">v3x</span><span class="p">:</span> <span class="s2">"""</span>
|
|
<span class="s2">#include "cutlass/gemm/device/gemm_universal_adapter.h"</span>
|
|
<span class="s2">#include "cutlass/gemm/collective/collective_builder.hpp"</span>
|
|
<span class="s2">#include "cutlass/gemm/device/gemm_universal_adapter.h"</span>
|
|
<span class="s2">#include "cutlass/gemm/kernel/gemm_universal.hpp"</span>
|
|
<span class="s2">#include "cutlass/epilogue/collective/default_epilogue.hpp"</span>
|
|
<span class="s2">#include "cutlass/util/packed_stride.hpp"</span>
|
|
<span class="s2">"""</span><span class="p">,</span>
|
|
<span class="p">}</span>
|
|
|
|
<span class="n">_PYTORCH_GROUPED_GEMM_INCLUDES</span> <span class="o">=</span> <span class="s2">"""</span>
|
|
<span class="s2">#include "cutlass/gemm/kernel/default_gemm_grouped.h"</span>
|
|
<span class="s2">#include "cutlass/gemm/device/gemm_grouped.h"</span>
|
|
<span class="s2">"""</span>
|
|
|
|
<span class="n">_CUTLASS_TYPE_TO_TORCH_TYPE</span> <span class="o">=</span> <span class="p">{</span>
|
|
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="s2">"torch::kF16"</span><span class="p">,</span>
|
|
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="s2">"torch::kF32"</span><span class="p">,</span>
|
|
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">float64</span><span class="p">:</span> <span class="s2">"torch::kF64"</span><span class="p">,</span>
|
|
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int8</span><span class="p">:</span> <span class="s2">"torch::I8"</span><span class="p">,</span>
|
|
<span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="s2">"torch::I32"</span><span class="p">,</span>
|
|
<span class="p">}</span>
|
|
|
|
<span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_2x</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_RUN_GEMM_2x</span>
|
|
<span class="o">+</span> <span class="s2">"""</span>
|
|
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">_kernel(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C, float alpha, float beta) {</span>
|
|
<span class="s2"> int M = A.size(0);</span>
|
|
<span class="s2"> int N = B.size(1);</span>
|
|
<span class="s2"> int K = A.size(1);</span>
|
|
|
|
<span class="s2"> typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?</span>
|
|
<span class="s2"> nullptr :</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(C-&gt;contiguous().data_ptr());</span>
|
|
<span class="s2"> at::Tensor D = B.new_empty({M, N}, $</span><span class="si">{torch_type_C}</span><span class="s2">);</span>
|
|
|
|
<span class="s2"> cutlass::Status status = $</span><span class="si">{name}</span><span class="s2">_kernel_run(M, N, K,</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementA*&gt;(A.contiguous().data_ptr()),</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementB*&gt;(B.contiguous().data_ptr()),</span>
|
|
<span class="s2"> ptrC,</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(D.contiguous().data_ptr()),</span>
|
|
<span class="s2"> ElementCompute(alpha), ElementCompute(beta));</span>
|
|
|
|
<span class="s2"> TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");</span>
|
|
<span class="s2"> return D;</span>
|
|
<span class="s2">}</span>
|
|
<span class="s2">"""</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_3x</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_RUN_GEMM_3x</span>
|
|
<span class="o">+</span> <span class="s2">"""</span>
|
|
<span class="s2">bool hw_info_queried = false;</span>
|
|
<span class="s2">cutlass::KernelHardwareInfo hw_info;</span>
|
|
|
|
<span class="s2">at::Tensor $</span><span class="si">{name}</span><span class="s2">_kernel(const at::Tensor&amp; A, const at::Tensor&amp; B, at::optional&lt;const at::Tensor&gt; C, float alpha, float beta) {</span>
|
|
<span class="s2"> int M = A.size(0);</span>
|
|
<span class="s2"> int N = B.size(1);</span>
|
|
<span class="s2"> int K = A.size(1);</span>
|
|
<span class="s2"> int L = 1;</span>
|
|
|
|
<span class="s2"> // Query hardware info if we haven't already</span>
|
|
<span class="s2"> if (!hw_info_queried) {</span>
|
|
<span class="s2"> hw_info.device_id = 0;</span>
|
|
<span class="s2"> hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);</span>
|
|
<span class="s2"> }</span>
|
|
|
|
<span class="s2"> typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?</span>
|
|
<span class="s2"> nullptr :</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(C-&gt;contiguous().data_ptr());</span>
|
|
<span class="s2"> at::Tensor D = B.new_empty({M, N}, $</span><span class="si">{torch_type_C}</span><span class="s2">);</span>
|
|
|
|
<span class="s2"> cutlass::Status status = $</span><span class="si">{name}</span><span class="s2">_kernel_run(M, N, K, L,</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementA*&gt;(A.contiguous().data_ptr()),</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementB*&gt;(B.contiguous().data_ptr()),</span>
|
|
<span class="s2"> ptrC,</span>
|
|
<span class="s2"> reinterpret_cast&lt;typename DeviceKernel::ElementC*&gt;(D.contiguous().data_ptr()),</span>
|
|
<span class="s2"> ElementCompute(alpha), ElementCompute(beta),</span>
|
|
<span class="s2"> hw_info);</span>
|
|
|
|
<span class="s2"> TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");</span>
|
|
<span class="s2"> return D;</span>
|
|
<span class="s2">}</span>
|
|
<span class="s2">"""</span>
|
|
<span class="p">)</span>
|
|
|
|
|
|
<span class="n">_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x</span>
|
|
<span class="o">+</span> <span class="s2">"""</span>
|
|
<span class="s2">std::vector&lt;at::Tensor&gt; $</span><span class="si">{name}</span><span class="s2">_kernel(const std::vector&lt;at::Tensor&gt;&amp; A, const std::vector&lt;at::Tensor&gt;&amp; B, at::optional&lt;const std::vector&lt;at::Tensor&gt;&gt; C, float alpha, float beta) {</span>
|
|
<span class="s2"> size_t num = A.size();</span>
|
|
|
|
<span class="s2"> // To avoid performing many small cudaMallocs and host-to-device copies,</span>
|
|
<span class="s2"> // we serialize the grouped GEMM arguments on the host, allocate one</span>
|
|
<span class="s2"> // large chunk of device memory, and perform a single cudaMemcpy to</span>
|
|
<span class="s2"> // copy the host data to the device. Allocation overheads could be</span>
|
|
<span class="s2"> // avoided by using a memory pool.</span>
|
|
|
|
<span class="s2"> // Calculate the total size of the data to be copied from host to device</span>
|
|
<span class="s2"> size_t total_size = sizeof(cutlass::gemm::GemmCoord) +</span>
|
|
<span class="s2"> sizeof(DeviceKernel::ElementA*) +</span>
|
|
<span class="s2"> sizeof(DeviceKernel::ElementB*) +</span>
|
|
<span class="s2"> sizeof(DeviceKernel::ElementC*) +</span>
|
|
<span class="s2"> sizeof(DeviceKernel::ElementC*) +</span>
|
|
<span class="s2"> sizeof(int64_t) +</span>
|
|
<span class="s2"> sizeof(int64_t) +</span>
|
|
<span class="s2"> sizeof(int64_t);</span>
|
|
<span class="s2"> total_size *= num;</span>
|
|
|
|
<span class="s2"> // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple</span>
|
|
<span class="s2"> // of sizeof(DeviceKernel::ElementA*) (which will be 8 bytes on a 64-bit system).</span>
|
|
<span class="s2"> // To ensure that we don't end up having misaligned loads in the kernel,</span>
|
|
<span class="s2"> // we pad to the nearest multiple of 8.</span>
|
|
<span class="s2"> //</span>
|
|
<span class="s2"> // Note that, even on a 32-bit system (for which sizeof(X*) will not equal</span>
|
|
<span class="s2"> // sizeof(int64_t)), only padding between the list of GemmCoords and the</span>
|
|
<span class="s2"> // list of ptr_As is sufficient because the set of four equal-length lists of pointers</span>
|
|
<span class="s2"> // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always</span>
|
|
<span class="s2"> // start on a multiple of 8.</span>
|
|
<span class="s2"> int64_t padding = 8 - (total_size % 8);</span>
|
|
<span class="s2"> total_size += padding;</span>
|
|
|
|
<span class="s2"> uint8_t* host_data = new uint8_t[total_size];</span>
|
|
<span class="s2"> cutlass::DeviceAllocation&lt;uint8_t&gt; device_data(total_size);</span>
|
|
|
|
<span class="s2"> uint8_t* start = host_data;</span>
|
|
<span class="s2"> cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);</span>
|
|
|
|
<span class="s2"> // Apply the padding after the list of GemmCoords</span>
|
|
<span class="s2"> start += num * sizeof(cutlass::gemm::GemmCoord) + padding;</span>
|
|
|
|
<span class="s2"> int64_t ptr_A_offset = start - host_data;</span>
|
|
<span class="s2"> DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);</span>
|
|
<span class="s2"> start += num * sizeof(DeviceKernel::ElementA*);</span>
|
|
|
|
<span class="s2"> int64_t ptr_B_offset = start - host_data;</span>
|
|
<span class="s2"> DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);</span>
|
|
<span class="s2"> start += num * sizeof(DeviceKernel::ElementB*);</span>
|
|
|
|
<span class="s2"> int64_t ptr_C_offset = start - host_data;</span>
|
|
<span class="s2"> DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);</span>
|
|
<span class="s2"> start += num * sizeof(DeviceKernel::ElementC*);</span>
|
|
|
|
<span class="s2"> int64_t ptr_D_offset = start - host_data;</span>
|
|
<span class="s2"> DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);</span>
|
|
<span class="s2"> start += num * sizeof(DeviceKernel::ElementC*);</span>
|
|
|
|
<span class="s2"> int64_t lda_offset = start - host_data;</span>
|
|
<span class="s2"> int64_t* lda_host = reinterpret_cast<int64_t*>(start);</span>
|
|
<span class="s2"> start += num * sizeof(int64_t);</span>
|
|
|
|
<span class="s2"> int64_t ldb_offset = start - host_data;</span>
|
|
<span class="s2"> int64_t* ldb_host = reinterpret_cast<int64_t*>(start);</span>
|
|
<span class="s2"> start += num * sizeof(int64_t);</span>
|
|
|
|
<span class="s2"> int64_t ldc_offset = start - host_data;</span>
|
|
<span class="s2"> int64_t* ldc_host = reinterpret_cast<int64_t*>(start);</span>
|
|
<span class="s2"> start += num * sizeof(int64_t);</span>
|
|
|
|
<span class="s2"> std::vector<at::Tensor> D(num);</span>
|
|
|
|
<span class="s2"> bool need_C = (C != at::nullopt) && (beta != 0.f);</span>
|
|
<span class="s2"> for (size_t i = 0; i < num; ++i) {</span>
|
|
<span class="s2"> int M = A[i].size(0);</span>
|
|
<span class="s2"> int N = B[i].size(1);</span>
|
|
<span class="s2"> int K = A[i].size(1);</span>
|
|
<span class="s2"> *(problem_sizes_host + i) = {M, N, K};</span>
|
|
<span class="s2"> *(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());</span>
|
|
<span class="s2"> *(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());</span>
|
|
|
|
<span class="s2"> if (need_C) {</span>
|
|
<span class="s2"> *(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());</span>
|
|
<span class="s2"> }</span>
|
|
<span class="s2"> else {</span>
|
|
<span class="s2"> *(ptr_C_host + i) = nullptr;</span>
|
|
<span class="s2"> }</span>
|
|
|
|
<span class="s2"> D[i] = B[i].new_empty({M, N}, $</span><span class="si">{torch_type_C}</span><span class="s2">);</span>
|
|
<span class="s2"> *(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());</span>
|
|
|
|
<span class="s2"> *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);</span>
|
|
<span class="s2"> *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);</span>
|
|
<span class="s2"> *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);</span>
|
|
<span class="s2"> }</span>
|
|
|
|
<span class="s2"> device_data.copy_from_host(host_data);</span>
|
|
|
|
<span class="s2"> cutlass::Status status = $</span><span class="si">{name}</span><span class="s2">_kernel_run(</span>
|
|
<span class="s2"> num,</span>
|
|
<span class="s2"> reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),</span>
|
|
<span class="s2"> reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),</span>
|
|
<span class="s2"> reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),</span>
|
|
<span class="s2"> reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),</span>
|
|
<span class="s2"> reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),</span>
|
|
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + lda_offset),</span>
|
|
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),</span>
|
|
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),</span>
|
|
<span class="s2"> reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),</span>
|
|
<span class="s2"> ElementCompute(alpha), ElementCompute(beta));</span>
|
|
|
|
<span class="s2"> delete[] host_data;</span>
|
|
|
|
<span class="s2"> TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");</span>
|
|
<span class="s2"> return D;</span>
|
|
<span class="s2">}</span>
|
|
<span class="s2">"""</span>
|
|
<span class="p">)</span>
|
|
|
|
|
|
<span class="n">_PYTORCH_SETUP_PY</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_PYSTYLE_AUTOGEN_COMMENT</span> <span class="o">+</span> <span class="s2">"""</span>
|
|
<span class="s2">from setuptools import setup</span>
|
|
<span class="s2">from torch.utils.cpp_extension import BuildExtension, CUDAExtension</span>
|
|
|
|
<span class="s2">setup(</span>
|
|
<span class="s2"> name='$</span><span class="si">{name}</span><span class="s2">',</span>
|
|
<span class="s2"> ext_modules=[</span>
|
|
<span class="s2"> CUDAExtension('$</span><span class="si">{name}</span><span class="s2">', [</span>
|
|
<span class="s2"> '$</span><span class="si">{name}</span><span class="s2">.cpp',</span>
|
|
<span class="s2"> '$</span><span class="si">{name}</span><span class="s2">_kernel.cu',</span>
|
|
<span class="s2"> ],</span>
|
|
<span class="s2"> include_dirs=['$</span><span class="si">{cutlass_path}</span><span class="s2">/include', '$</span><span class="si">{cutlass_path}</span><span class="s2">/tools/util/include'],</span>
|
|
<span class="s2"> extra_compile_args=['-std=c++17']</span>
|
|
<span class="s2"> ),</span>
|
|
<span class="s2"> ],</span>
|
|
<span class="s2"> cmdclass={</span>
|
|
<span class="s2"> 'build_ext': BuildExtension</span>
|
|
<span class="s2"> })</span>
|
|
|
|
<span class="s2">"""</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">_generate_setup</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Generates a setup.py file for the extension</span>
|
|
|
|
<span class="sd"> :param name: name of the module to generate</span>
|
|
<span class="sd"> :type name: str</span>
|
|
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
|
<span class="sd"> :type sourcedir: str</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">setup_py_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="s2">"setup.py"</span><span class="p">)</span>
|
|
<span class="n">setup_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
|
<span class="n">_PYTORCH_SETUP_PY</span><span class="p">,</span> <span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="s2">"cutlass_path"</span><span class="p">:</span> <span class="n">CUTLASS_PATH</span><span class="p">}</span>
|
|
<span class="p">)</span>
|
|
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">setup_py_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
|
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">setup_source</span><span class="p">)</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">_ArchListSetter</span><span class="p">:</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``</span>
|
|
<span class="sd"> environment variable when building a PyTorch CUDA module.</span>
|
|
|
|
<span class="sd"> ``TORCH_CUDA_ARCH_LIST`` is a space-delimited list of compute capabilities for which a PyTorch</span>
|
|
<span class="sd"> CUDA module should be compiled.</span>
|
|
|
|
<span class="sd"> For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of</span>
|
|
<span class="sd"> ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the</span>
|
|
<span class="sd"> compilation of the module.</span>
|
|
|
|
<span class="sd"> This utility wraps the building of a PyTorch CUDA module with a setting of this environment</span>
|
|
<span class="sd"> variable according to the current compute capability being targeted.</span>
|
|
|
|
<span class="sd"> Example usage:</span>
|
|
|
|
<span class="sd"> .. highlight:: python</span>
|
|
<span class="sd"> .. code-block:: python</span>
|
|
|
|
<span class="sd"> # Temporarily set TORCH_CUDA_ARCH_LIST="8.0"</span>
|
|
<span class="sd"> with _ArchListSetter(80):</span>
|
|
<span class="sd"> # Perform JIT compilation and loading of the module</span>
|
|
<span class="sd"> mod = torch.utils.cpp_extension.load(...)</span>
|
|
|
|
<span class="sd"> :param cc: compute capability</span>
|
|
<span class="sd"> :type cc: int</span>
|
|
<span class="sd"> """</span>
|
|
|
|
<span class="n">_TORCH_CUDA_ARCH_LIST</span> <span class="o">=</span> <span class="s2">"TORCH_CUDA_ARCH_LIST"</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cc_str</span> <span class="o">=</span> <span class="s2">"."</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">cc</span><span class="p">)))</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__enter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Saves the old value of TORCH_CUDA_ARCH_LIST and resets it to the new value based on ``cc``</span>
|
|
<span class="sd"> """</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">old_arch_list</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">getenv</span><span class="p">(</span><span class="n">_ArchListSetter</span><span class="o">.</span><span class="n">_TORCH_CUDA_ARCH_LIST</span><span class="p">)</span>
|
|
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="n">_ArchListSetter</span><span class="o">.</span><span class="n">_TORCH_CUDA_ARCH_LIST</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cc_str</span>
|
|
|
|
<span class="k">return</span> <span class="bp">self</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__exit__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_val</span><span class="p">,</span> <span class="n">traceback</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Restores the old value of TORCH_CUDA_ARCH_LIST</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="n">_ArchListSetter</span><span class="o">.</span><span class="n">_TORCH_CUDA_ARCH_LIST</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">old_arch_list</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">_jit</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">cpp_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> JIT compiles and loads a PyTorch CUDA extension.</span>
|
|
|
|
<span class="sd"> :param name: name of the module to generate</span>
|
|
<span class="sd"> :type name: str</span>
|
|
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
|
<span class="sd"> :type cc: int</span>
|
|
<span class="sd"> :param cpp_file: path to file containing extension's C++ interface</span>
|
|
<span class="sd"> :type cpp_file: str</span>
|
|
<span class="sd"> :param cuda_file: path to file containing extension's CUDA interface</span>
|
|
<span class="sd"> :type cuda_file: str</span>
|
|
|
|
<span class="sd"> :return: loaded PyTorch module</span>
|
|
<span class="sd"> """</span>
|
|
|
|
<span class="kn">from</span> <span class="nn">torch.utils.cpp_extension</span> <span class="kn">import</span> <span class="n">load</span>
|
|
|
|
<span class="n">extra_cuda_cflags</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"-std=c++17"</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">cc</span> <span class="o">==</span> <span class="mi">90</span><span class="p">:</span>
|
|
<span class="c1"># PyTorch does not currently add the sm_90a target when compute capability</span>
|
|
<span class="c1"># 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.</span>
|
|
<span class="n">extra_cuda_cflags</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">"-gencode=arch=compute_90a,code=sm_90a"</span><span class="p">)</span>
|
|
|
|
<span class="k">with</span> <span class="n">_ArchListSetter</span><span class="p">(</span><span class="n">cc</span><span class="p">):</span>
|
|
<span class="n">jitmodule</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="p">[</span><span class="n">cpp_file</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">],</span>
|
|
<span class="n">extra_cuda_cflags</span><span class="o">=</span><span class="n">extra_cuda_cflags</span><span class="p">,</span>
|
|
<span class="n">extra_include_paths</span><span class="o">=</span><span class="p">[</span>
|
|
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">CUTLASS_PATH</span><span class="p">,</span> <span class="s2">"include"</span><span class="p">),</span>
|
|
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">CUTLASS_PATH</span><span class="p">,</span> <span class="s2">"tools/util/include"</span><span class="p">),</span>
|
|
<span class="p">],</span>
|
|
<span class="n">verbose</span><span class="o">=</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">level</span> <span class="o">==</span> <span class="n">logging</span><span class="o">.</span><span class="n">DEBUG</span><span class="p">)</span>
|
|
<span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">jitmodule</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">_pytorch_gemm</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">jit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM</span>
|
|
<span class="sd"> specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time</span>
|
|
<span class="sd"> compiled, loaded, and returned.</span>
|
|
|
|
<span class="sd"> :param op: operation to emit in the module</span>
|
|
<span class="sd"> :param name: name of the module to generate</span>
|
|
<span class="sd"> :type name: str</span>
|
|
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
|
<span class="sd"> :type cc: int</span>
|
|
<span class="sd"> :param jit: whether the module should be just-in-time compiled</span>
|
|
<span class="sd"> :type jit: bool</span>
|
|
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
|
<span class="sd"> :type sourcedir: str</span>
|
|
|
|
<span class="sd"> :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">if</span> <span class="n">sourcedir</span> <span class="o">!=</span> <span class="s2">""</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">):</span>
|
|
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">)</span>
|
|
|
|
<span class="n">cuda_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">"_kernel.cu"</span><span class="p">)</span>
|
|
<span class="n">extra_kw</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">api</span> <span class="o">==</span> <span class="n">ApiVersion</span><span class="o">.</span><span class="n">v3x</span><span class="p">:</span>
|
|
<span class="n">impl_template</span> <span class="o">=</span> <span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_3x</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">impl_template</span> <span class="o">=</span> <span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_2x</span>
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">op</span><span class="o">.</span><span class="n">swizzling_functor</span><span class="p">,</span> <span class="n">swizzle</span><span class="o">.</span><span class="n">ThreadblockSwizzleStreamK</span><span class="p">):</span>
|
|
<span class="n">extra_kw</span><span class="p">[</span><span class="s2">"args"</span><span class="p">]</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_ARGS_2x_STREAM_K</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">extra_kw</span><span class="p">[</span><span class="s2">"args"</span><span class="p">]</span> <span class="o">=</span> <span class="n">common</span><span class="o">.</span><span class="n">_CUTLASS_KERNEL_ARGS_2x</span>
|
|
<span class="n">impl_template</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_3x</span>
|
|
<span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">api</span> <span class="o">==</span> <span class="n">ApiVersion</span><span class="o">.</span><span class="n">v3x</span>
|
|
<span class="k">else</span> <span class="n">_PYTORCH_GEMM_IMPL_TEMPLATE_2x</span>
|
|
<span class="p">)</span>
|
|
<span class="n">cuda_impl</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span><span class="n">impl_template</span><span class="p">,</span> <span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">extra_kw</span><span class="p">})</span>
|
|
<span class="n">cuda_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
|
<span class="n">_PYTORCH_CUDA_TEMPLATE</span><span class="p">,</span>
|
|
<span class="p">{</span>
|
|
<span class="s2">"includes"</span><span class="p">:</span> <span class="n">_PYTORCH_GEMM_INCLUDES</span><span class="p">[</span><span class="n">op</span><span class="o">.</span><span class="n">api</span><span class="p">],</span>
|
|
<span class="s2">"declaration"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">rt_module</span><span class="o">.</span><span class="n">emit</span><span class="p">(),</span>
|
|
<span class="s2">"procedural_name"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">(),</span>
|
|
<span class="s2">"impl"</span><span class="p">:</span> <span class="n">cuda_impl</span><span class="p">,</span>
|
|
<span class="s2">"torch_type_C"</span><span class="p">:</span> <span class="n">_CUTLASS_TYPE_TO_TORCH_TYPE</span><span class="p">[</span><span class="n">op</span><span class="o">.</span><span class="n">C</span><span class="o">.</span><span class="n">element</span><span class="p">],</span>
|
|
<span class="p">},</span>
|
|
<span class="p">)</span>
|
|
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cuda_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
|
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cuda_source</span><span class="p">)</span>
|
|
|
|
<span class="n">cpp_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">".cpp"</span><span class="p">)</span>
|
|
<span class="n">cpp_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
|
<span class="n">_PYTORCH_GEMM_CPP_TEMPLATE</span><span class="p">,</span>
|
|
<span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="s2">"description"</span><span class="p">:</span> <span class="sa">f</span><span class="s2">"CUTLASS </span><span class="si">{</span><span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">()</span><span class="si">}</span><span class="s2"> GEMM"</span><span class="p">},</span>
|
|
<span class="p">)</span>
|
|
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cpp_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
|
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cpp_source</span><span class="p">)</span>
|
|
|
|
<span class="n">_generate_setup</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">jit</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">_jit</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">cpp_file</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="kc">None</span>
|
|
|
|
|
|
<span class="k">def</span> <span class="nf">_pytorch_grouped_gemm</span><span class="p">(</span>
|
|
<span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">jit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span>
|
|
<span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM</span>
|
|
<span class="sd"> specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time</span>
|
|
<span class="sd"> compiled, loaded, and returned.</span>
|
|
|
|
<span class="sd"> :param op: operation to emit in the module</span>
|
|
<span class="sd"> :param name: name of the module to generate</span>
|
|
<span class="sd"> :type name: str</span>
|
|
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
|
<span class="sd"> :type cc: int</span>
|
|
<span class="sd"> :param jit: whether the module should be just-in-time compiled</span>
|
|
<span class="sd"> :type jit: bool</span>
|
|
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
|
<span class="sd"> :type sourcedir: str</span>
|
|
|
|
<span class="sd"> :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">api</span> <span class="o">!=</span> <span class="n">ApiVersion</span><span class="o">.</span><span class="n">v2x</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">"Grouped GEMM is currently only supported for CUTLASS 2.x"</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">sourcedir</span> <span class="o">!=</span> <span class="s2">""</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">):</span>
|
|
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">)</span>
|
|
|
|
<span class="n">cuda_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">"_kernel.cu"</span><span class="p">)</span>
|
|
<span class="n">cuda_impl</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span><span class="n">_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE</span><span class="p">,</span> <span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">})</span>
|
|
<span class="n">cuda_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
|
<span class="n">_PYTORCH_CUDA_TEMPLATE</span><span class="p">,</span>
|
|
<span class="p">{</span>
|
|
<span class="s2">"includes"</span><span class="p">:</span> <span class="n">_PYTORCH_GROUPED_GEMM_INCLUDES</span><span class="p">,</span>
|
|
<span class="s2">"declaration"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">rt_module</span><span class="o">.</span><span class="n">emit</span><span class="p">(),</span>
|
|
<span class="s2">"procedural_name"</span><span class="p">:</span> <span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">(),</span>
|
|
<span class="s2">"impl"</span><span class="p">:</span> <span class="n">cuda_impl</span><span class="p">,</span>
|
|
<span class="s2">"torch_type_C"</span><span class="p">:</span> <span class="n">_CUTLASS_TYPE_TO_TORCH_TYPE</span><span class="p">[</span><span class="n">op</span><span class="o">.</span><span class="n">C</span><span class="o">.</span><span class="n">element</span><span class="p">],</span>
|
|
<span class="p">},</span>
|
|
<span class="p">)</span>
|
|
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cuda_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
|
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cuda_source</span><span class="p">)</span>
|
|
|
|
<span class="n">cpp_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sourcedir</span><span class="p">,</span> <span class="n">name</span> <span class="o">+</span> <span class="s2">".cpp"</span><span class="p">)</span>
|
|
<span class="n">cpp_source</span> <span class="o">=</span> <span class="n">SubstituteTemplate</span><span class="p">(</span>
|
|
<span class="n">_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE</span><span class="p">,</span>
|
|
<span class="p">{</span><span class="s2">"name"</span><span class="p">:</span> <span class="n">name</span><span class="p">,</span> <span class="s2">"description"</span><span class="p">:</span> <span class="sa">f</span><span class="s2">"CUTLASS </span><span class="si">{</span><span class="n">op</span><span class="o">.</span><span class="n">procedural_name</span><span class="p">()</span><span class="si">}</span><span class="s2"> grouped GEMM"</span><span class="p">},</span>
|
|
<span class="p">)</span>
|
|
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">cpp_file</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
|
|
<span class="n">outfile</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">cpp_source</span><span class="p">)</span>
|
|
|
|
<span class="n">_generate_setup</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">jit</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">_jit</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">cpp_file</span><span class="p">,</span> <span class="n">cuda_file</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="kc">None</span>
|
|
|
|
|
|
<div class="viewcode-block" id="pytorch"><a class="viewcode-back" href="../../../cutlass.emit.html#cutlass.emit.pytorch.pytorch">[docs]</a><span class="k">def</span> <span class="nf">pytorch</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">cc</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">jit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel</span>
|
|
<span class="sd"> specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time</span>
|
|
<span class="sd"> compiled, loaded, and returned.</span>
|
|
|
|
<span class="sd"> The result of this method is files within ``sourcedir`` that can be used for building</span>
|
|
<span class="sd"> a PyTorch module.</span>
|
|
|
|
<span class="sd"> :param op: operation to emit in the module</span>
|
|
<span class="sd"> :param name: name of the module to generate</span>
|
|
<span class="sd"> :type name: str</span>
|
|
<span class="sd"> :param cc: compute capability of the device the module should target</span>
|
|
<span class="sd"> :type cc: int</span>
|
|
<span class="sd"> :param jit: whether the module should be just-in-time compiled</span>
|
|
<span class="sd"> :type jit: bool</span>
|
|
<span class="sd"> :param sourcedir: directory to which generated source files should be written</span>
|
|
<span class="sd"> :type sourcedir: str</span>
|
|
|
|
<span class="sd"> :return: loaded PyTorch module (if ``jit=True``) or None</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">device_op</span> <span class="o">=</span> <span class="n">op</span><span class="o">.</span><span class="n">device_op</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">GemmOperationUniversal</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">_pytorch_gemm</span><span class="p">(</span><span class="n">device_op</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">GemmOperationGrouped</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">_pytorch_grouped_gemm</span><span class="p">(</span><span class="n">device_op</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">sourcedir</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"Operation type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">op</span><span class="p">)</span><span class="si">}</span><span class="s2"> is not currently supported for PyTorch emission."</span>
|
|
<span class="p">)</span></div>
|
|
</pre></div>
|
|
</article>
|
|
</div>
|
|
<footer>
|
|
|
|
<div class="related-pages">
|
|
|
|
|
|
</div>
|
|
<div class="bottom-of-page">
|
|
<div class="left-details">
|
|
<div class="copyright">
|
|
Copyright © 2023, NVIDIA
|
|
</div>
|
|
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
|
|
|
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
|
|
|
</div>
|
|
<div class="right-details">
|
|
<div class="icons">
|
|
<a class="muted-link" href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
|
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16" aria-hidden="true">
|
|
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
|
</svg>
|
|
</a>
|
|
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
</footer>
|
|
</div>
|
|
<aside class="toc-drawer no-toc">
|
|
|
|
|
|
|
|
</aside>
|
|
</div>
|
|
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
|
<script src="../../../_static/doctools.js"></script>
|
|
<script src="../../../_static/sphinx_highlight.js"></script>
|
|
<script src="../../../_static/scripts/furo.js"></script>
|
|
<script src="../../../_static/clipboard.min.js"></script>
|
|
<script src="../../../_static/copybutton.js"></script>
|
|
<script src="../../../_static/tabs.js"></script>
|
|
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
|
</body>
|
|
</html> |