<!--
mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-13 17:55:42 +00:00
554 lines
50 KiB
HTML
554 lines
50 KiB
HTML
-->
<!doctype html>
|
|
<html class="no-js" lang="en">
|
|
<head><meta charset="utf-8"/>
|
|
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
|
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../../../genindex.html" /><link rel="search" title="Search" href="../../../search.html" />
|
|
<link rel="canonical" href="docs/_modules/cutlass/op/gemm_grouped.html" />
|
|
|
|
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
|
|
<title>cutlass.op.gemm_grouped - CUTLASS Python</title>
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/tabs.css" />
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
|
|
|
|
|
|
|
|
|
|
<style>
  /*
   * Colour overrides layered on top of the Furo theme stylesheets.
   * The rules below must contain only valid CSS: any stray token before a
   * declaration or at-rule makes the parser discard it during error recovery.
   */

  /* Default (light) palette: code block colours and NVIDIA brand green. */
  body {
    --color-code-background: #eeffcc;
    --color-code-foreground: black;
    --color-brand-primary: #76B900;
    --color-brand-content: #76B900;
  }

  @media not print {
    /* Dark palette when the dark theme has been explicitly selected. */
    body[data-theme="dark"] {
      --color-code-background: #272822;
      --color-code-foreground: #f8f8f2;
      --color-brand-primary: #76B900;
      --color-brand-content: #76B900;
    }

    /* When the OS prefers dark, use the dark palette unless light was explicitly selected. */
    @media (prefers-color-scheme: dark) {
      body:not([data-theme="light"]) {
        --color-code-background: #272822;
        --color-code-foreground: #f8f8f2;
        --color-brand-primary: #76B900;
        --color-brand-content: #76B900;
      }
    }
  }
</style></head>
|
|
<body>
|
|
|
|
<script>
  // Runs inline at the top of <body>, so the theme attribute is set before the
  // rest of the page is parsed. Uses the theme saved in localStorage and falls
  // back to "auto" when none is stored. The script must contain only valid JS:
  // a stray token would raise a SyntaxError and the theme would never be set.
  document.body.dataset.theme = localStorage.getItem("theme") || "auto";
</script>
|
|
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
|
|
<symbol id="svg-toc" viewBox="0 0 24 24">
|
|
<title>Contents</title>
|
|
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
|
|
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
|
|
</svg>
|
|
</symbol>
|
|
<symbol id="svg-menu" viewBox="0 0 24 24">
|
|
<title>Menu</title>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
|
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
|
|
<line x1="3" y1="12" x2="21" y2="12"></line>
|
|
<line x1="3" y1="6" x2="21" y2="6"></line>
|
|
<line x1="3" y1="18" x2="21" y2="18"></line>
|
|
</svg>
|
|
</symbol>
|
|
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
|
|
<title>Expand</title>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
|
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
|
|
<polyline points="9 18 15 12 9 6"></polyline>
|
|
</svg>
|
|
</symbol>
|
|
<symbol id="svg-sun" viewBox="0 0 24 24">
|
|
<title>Light mode</title>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
|
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
|
|
<circle cx="12" cy="12" r="5"></circle>
|
|
<line x1="12" y1="1" x2="12" y2="3"></line>
|
|
<line x1="12" y1="21" x2="12" y2="23"></line>
|
|
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
|
|
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
|
|
<line x1="1" y1="12" x2="3" y2="12"></line>
|
|
<line x1="21" y1="12" x2="23" y2="12"></line>
|
|
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
|
|
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
|
|
</svg>
|
|
</symbol>
|
|
<symbol id="svg-moon" viewBox="0 0 24 24">
|
|
<title>Dark mode</title>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
|
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
|
|
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
|
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
|
|
</svg>
|
|
</symbol>
|
|
<symbol id="svg-sun-half" viewBox="0 0 24 24">
|
|
<title>Auto light/dark mode</title>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
|
|
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
|
|
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
|
<circle cx="12" cy="12" r="9" />
|
|
<path d="M13 12h5" />
|
|
<path d="M13 15h4" />
|
|
<path d="M13 18h1" />
|
|
<path d="M13 9h4" />
|
|
<path d="M13 6h1" />
|
|
</svg>
|
|
</symbol>
|
|
</svg>
|
|
|
|
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
|
|
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
|
|
<label class="overlay sidebar-overlay" for="__navigation">
|
|
<div class="visually-hidden">Hide navigation sidebar</div>
|
|
</label>
|
|
<label class="overlay toc-overlay" for="__toc">
|
|
<div class="visually-hidden">Hide table of contents sidebar</div>
|
|
</label>
|
|
|
|
|
|
|
|
<div class="page">
|
|
<header class="mobile-header">
|
|
<div class="header-left">
|
|
<label class="nav-overlay-icon" for="__navigation">
|
|
<div class="visually-hidden">Toggle site navigation sidebar</div>
|
|
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
|
|
</label>
|
|
</div>
|
|
<div class="header-center">
|
|
<a href="../../../index.html"><div class="brand">CUTLASS Python</div></a>
|
|
</div>
|
|
<div class="header-right">
|
|
<div class="theme-toggle-container theme-toggle-header">
|
|
<button class="theme-toggle">
|
|
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
|
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
|
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
|
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
|
</button>
|
|
</div>
|
|
<label class="toc-overlay-icon toc-header-icon no-toc" for="__toc">
|
|
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
|
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
|
</label>
|
|
</div>
|
|
</header>
|
|
<aside class="sidebar-drawer">
|
|
<div class="sidebar-container">
|
|
|
|
<div class="sidebar-sticky"><a class="sidebar-brand" href="../../../index.html">
|
|
|
|
<div class="sidebar-logo-container">
|
|
<img class="sidebar-logo only-light" src="../../../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
|
<img class="sidebar-logo only-dark" src="../../../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
|
</div>
|
|
|
|
<span class="sidebar-brand-text">CUTLASS Python</span>
|
|
|
|
</a><form class="sidebar-search-container" method="get" action="../../../search.html" role="search">
|
|
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
|
<input type="hidden" name="check_keywords" value="yes">
|
|
<input type="hidden" name="area" value="default">
|
|
</form>
|
|
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../index.html">Home</a></li>
|
|
</ul>
|
|
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../install.html">Installation</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Getting Started</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../contribute.html">Contributing</a></li>
|
|
</ul>
|
|
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
|
<ul>
|
|
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
|
<li class="toctree-l2 has-children"><a class="reference internal" href="../../../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
|
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.emit.html">Emitters</a></li>
|
|
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.op.html">Operations</a></li>
|
|
<li class="toctree-l3"><a class="reference internal" href="../../../cutlass.utils.html">Utilities</a></li>
|
|
</ul>
|
|
</li>
|
|
</ul>
|
|
</li>
|
|
</ul>
|
|
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
|
<ul>
|
|
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples.html">Examples</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../externals/00_basic_gemm.html">Basic GEMM</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../externals/01_epilogue.html">Epilogue</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../externals/02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
|
|
</ul>
|
|
</li>
|
|
</ul>
|
|
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
|
|
</ul>
|
|
|
|
</div>
|
|
</div>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
</aside>
|
|
<div class="main">
|
|
<div class="content">
|
|
<div class="article-container">
|
|
<a href="#" class="back-to-top muted-link">
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
|
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
|
</svg>
|
|
<span>Back to top</span>
|
|
</a>
|
|
<div class="content-icon-container">
|
|
<div class="theme-toggle-container theme-toggle-content">
|
|
<button class="theme-toggle">
|
|
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
|
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
|
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
|
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
|
</button>
|
|
</div>
|
|
<label class="toc-overlay-icon toc-content-icon no-toc" for="__toc">
|
|
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
|
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
|
</label>
|
|
</div>
|
|
<article role="main">
|
|
<h1>Source code for cutlass.op.gemm_grouped</h1><div class="highlight"><pre>
|
|
<span></span><span class="c1">#################################################################################################</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
|
<span class="c1"># SPDX-License-Identifier: BSD-3-Clause</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Redistribution and use in source and binary forms, with or without</span>
|
|
<span class="c1"># modification, are permitted provided that the following conditions are met:</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># 1. Redistributions of source code must retain the above copyright notice, this</span>
|
|
<span class="c1"># list of conditions and the following disclaimer.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># 2. Redistributions in binary form must reproduce the above copyright notice,</span>
|
|
<span class="c1"># this list of conditions and the following disclaimer in the documentation</span>
|
|
<span class="c1"># and/or other materials provided with the distribution.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># 3. Neither the name of the copyright holder nor the names of its</span>
|
|
<span class="c1"># contributors may be used to endorse or promote products derived from</span>
|
|
<span class="c1"># this software without specific prior written permission.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"</span>
|
|
<span class="c1"># AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE</span>
|
|
<span class="c1"># IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE</span>
|
|
<span class="c1"># DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE</span>
|
|
<span class="c1"># FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL</span>
|
|
<span class="c1"># DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR</span>
|
|
<span class="c1"># SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER</span>
|
|
<span class="c1"># CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,</span>
|
|
<span class="c1"># OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE</span>
|
|
<span class="c1"># OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1">#################################################################################################</span>
|
|
|
|
<span class="sd">"""</span>
|
|
<span class="sd"> Ease-of-use interface for constructing, compiling, and running grouped GEMMs.</span>
|
|
|
|
<span class="sd"> The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run</span>
|
|
<span class="sd"> grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.</span>
|
|
<span class="sd"> Under the hood, the interface will select sensible default parameters for the many template</span>
|
|
<span class="sd"> parameters for CUTLASS grouped GEMMs.</span>
|
|
|
|
<span class="sd"> Note: optimal performance is not to be expected from this interface. To achieve optimal</span>
|
|
<span class="sd"> performance, one should specify and tune each configuration parameter.</span>
|
|
|
|
<span class="sd"> The simplest example of using this interface is the following:</span>
|
|
|
|
<span class="sd"> .. highlight:: python</span>
|
|
<span class="sd"> .. code-block:: python</span>
|
|
|
|
<span class="sd"> # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects</span>
|
|
<span class="sd"> plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)</span>
|
|
<span class="sd"> plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])</span>
|
|
<span class="sd">"""</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">cutlass_bindings</span>
|
|
|
|
<span class="kn">from</span> <span class="nn">cutlass.backend.gemm_operation</span> <span class="kn">import</span> <span class="p">(</span>
|
|
<span class="n">GemmGroupedArguments</span><span class="p">,</span>
|
|
<span class="n">GemmOperationGrouped</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="kn">from</span> <span class="nn">cutlass.backend.library</span> <span class="kn">import</span> <span class="p">(</span>
|
|
<span class="n">DataTypeSize</span><span class="p">,</span>
|
|
<span class="n">SchedulerMode</span><span class="p">,</span>
|
|
<span class="n">TensorDescription</span><span class="p">,</span>
|
|
<span class="n">TileDescription</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="kn">from</span> <span class="nn">cutlass.op.gemm</span> <span class="kn">import</span> <span class="n">Gemm</span>
|
|
<span class="kn">from</span> <span class="nn">cutlass.utils</span> <span class="kn">import</span> <span class="n">check</span><span class="p">,</span> <span class="n">datatypes</span>
|
|
|
|
|
|
<div class="viewcode-block" id="GroupedGemm"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm_grouped.GroupedGemm">[docs]</a><span class="k">class</span> <span class="nc">GroupedGemm</span><span class="p">(</span><span class="n">Gemm</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Constructs a ``GroupedGemm`` object.</span>
|
|
|
|
<span class="sd"> The data types and layouts of operands A, B, and C, along with the data type of output D</span>
|
|
<span class="sd"> and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --</span>
|
|
<span class="sd"> these are not to be changed after a ``GroupedGemm`` has been constructed.</span>
|
|
|
|
<span class="sd"> The constructor has optional parameters for flexibly setting these parameters. Please see the constructor</span>
|
|
<span class="sd"> for ``Gemm`` for examples of these.</span>
|
|
|
|
<span class="sd"> :param cc: compute capability of device to generate kernels for</span>
|
|
<span class="sd"> :type cc: int</span>
|
|
<span class="sd"> :param A: tensor representing data type and layout of operands A</span>
|
|
<span class="sd"> :param B: tensor representing data type and layout of operands B</span>
|
|
<span class="sd"> :param C: tensor representing data type and layout of operands C</span>
|
|
<span class="sd"> :param D: tensor representing data type and layout of operands D</span>
|
|
<span class="sd"> :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B</span>
|
|
<span class="sd"> :param beta: scalar parameter beta from GEMM operation that scales operand C</span>
|
|
<span class="sd"> :param element_accumulator: data type to be used in accumulation of the product of operands A and B</span>
|
|
<span class="sd"> :type element_accumulator: cutlass.DataType</span>
|
|
<span class="sd"> :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type</span>
|
|
<span class="sd"> :type element: cutlass.DataType</span>
|
|
<span class="sd"> :param layout: generic layout type to be used for operands A, B, C, and D</span>
|
|
<span class="sd"> :type layout: cutlass.LayoutType</span>
|
|
<span class="sd"> :param element_A: data type to be used for operand A</span>
|
|
<span class="sd"> :type element_A: cutlass.DataType</span>
|
|
<span class="sd"> :param element_B: data type to be used for operand B</span>
|
|
<span class="sd"> :type element_B: cutlass.DataType</span>
|
|
<span class="sd"> :param element_C: data type to be used for operand C</span>
|
|
<span class="sd"> :type element_C: cutlass.DataType</span>
|
|
<span class="sd"> :param element_D: data type to be used for operand D</span>
|
|
<span class="sd"> :type element_D: cutlass.DataType</span>
|
|
<span class="sd"> :param layout_A: layout of operand A</span>
|
|
<span class="sd"> :type layout_A: cutlass.LayoutType</span>
|
|
<span class="sd"> :param layout_B: layout of operand B</span>
|
|
<span class="sd"> :type layout_B: cutlass.LayoutType</span>
|
|
<span class="sd"> :param layout_C: layout of operand C</span>
|
|
<span class="sd"> :type layout_C: cutlass.LayoutType</span>
|
|
<span class="sd"> :param layout_D: layout of operand D</span>
|
|
<span class="sd"> :type layout_D: cutlass.LayoutType</span>
|
|
<span class="sd"> """</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">alpha</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">element_accumulator</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">element</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">element_A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">element_D</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">layout_A</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout_B</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">layout_C</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">cc</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
|
|
<span class="n">A</span><span class="o">=</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">C</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="n">D</span><span class="p">,</span>
|
|
<span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="n">beta</span><span class="p">,</span>
|
|
<span class="n">element_accumulator</span><span class="o">=</span><span class="n">element_accumulator</span><span class="p">,</span>
|
|
<span class="n">element</span><span class="o">=</span><span class="n">element</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">layout</span><span class="p">,</span>
|
|
<span class="n">element_A</span><span class="o">=</span><span class="n">element_A</span><span class="p">,</span> <span class="n">element_B</span><span class="o">=</span><span class="n">element_B</span><span class="p">,</span>
|
|
<span class="n">element_C</span><span class="o">=</span><span class="n">element_C</span><span class="p">,</span> <span class="n">element_D</span><span class="o">=</span><span class="n">element_D</span><span class="p">,</span>
|
|
<span class="n">layout_A</span><span class="o">=</span><span class="n">layout_A</span><span class="p">,</span> <span class="n">layout_B</span><span class="o">=</span><span class="n">layout_B</span><span class="p">,</span> <span class="n">layout_C</span><span class="o">=</span><span class="n">layout_C</span><span class="p">,</span>
|
|
<span class="n">cc</span><span class="o">=</span><span class="n">cc</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="c1"># Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span> <span class="o">==</span> <span class="mi">90</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_options</span><span class="p">(</span><span class="mi">80</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_operations</span><span class="p">(</span><span class="n">reset_epilogue</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="s2">"grouped_gemm"</span>
|
|
|
|
<span class="nd">@Gemm</span><span class="o">.</span><span class="n">swizzling_functor</span><span class="o">.</span><span class="n">setter</span>
|
|
<span class="k">def</span> <span class="nf">swizzling_functor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">swizzling_functor</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Grouped GEMM does not support changing the swizzling functor; setting `swizzling_functor` always raises an exception.</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">'Grouped GEMM does not currently support different swizzling functors'</span><span class="p">)</span>
|
|
|
|
<div class="viewcode-block" id="GroupedGemm.construct"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm_grouped.GroupedGemm.construct">[docs]</a> <span class="k">def</span> <span class="nf">construct</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tile_description</span><span class="p">:</span> <span class="n">TileDescription</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">alignment_A</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">alignment_B</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">alignment_C</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">GemmOperationGrouped</span><span class="p">:</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current</span>
|
|
<span class="sd"> kernel specification of the ``GroupedGemm`` object.</span>
|
|
|
|
<span class="sd"> :param tile_description: tile description specifying shapes and operand types to use in the kernel</span>
|
|
<span class="sd"> :type tile_description: cutlass.backend.TileDescription</span>
|
|
<span class="sd"> :param alignment_A: alignment of operand A</span>
|
|
<span class="sd"> :type alignment_A: int</span>
|
|
<span class="sd"> :param alignment_B: alignment of operand B</span>
|
|
<span class="sd"> :type alignment_B: int</span>
|
|
<span class="sd"> :param alignment_C: alignment of operand C</span>
|
|
<span class="sd"> :type alignment_C: int</span>
|
|
|
|
<span class="sd"> :return: operation that was constructed</span>
|
|
<span class="sd"> :rtype: cutlass.backend.GemmOperationGrouped</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">alignment_preference</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">alignments</span><span class="p">)</span>
|
|
<span class="n">alignment_A</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_A</span><span class="p">,</span> <span class="n">alignment_preference</span><span class="p">)</span>
|
|
<span class="n">alignment_B</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_B</span><span class="p">,</span> <span class="n">alignment_preference</span><span class="p">)</span>
|
|
<span class="n">alignment_C</span> <span class="o">=</span> <span class="n">check</span><span class="o">.</span><span class="n">alignment_or_default</span><span class="p">(</span><span class="n">alignment_C</span><span class="p">,</span> <span class="n">alignment_preference</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_reset_epilogue_functor_alignment</span><span class="p">(</span><span class="n">alignment_C</span><span class="p">)</span>
|
|
|
|
<span class="n">tensor_A</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
|
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">),</span>
|
|
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">),</span>
|
|
<span class="n">alignment_A</span>
|
|
<span class="p">)</span>
|
|
<span class="n">tensor_B</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
|
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">),</span>
|
|
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">),</span>
|
|
<span class="n">alignment_B</span>
|
|
<span class="p">)</span>
|
|
<span class="n">tensor_C</span> <span class="o">=</span> <span class="n">TensorDescription</span><span class="p">(</span>
|
|
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">),</span>
|
|
<span class="n">datatypes</span><span class="o">.</span><span class="n">binding_layout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">),</span>
|
|
<span class="n">alignment_C</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">tile_description</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">op</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">operations</span><span class="p">(</span><span class="n">alignment_A</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">tile_description</span> <span class="o">=</span> <span class="n">datatypes</span><span class="o">.</span><span class="n">td_from_profiler_op</span><span class="p">(</span><span class="n">op</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">valid</span><span class="p">,</span> <span class="n">err_str</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_valid_tile_description</span><span class="p">(</span><span class="n">tile_description</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">valid</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid tile description. </span><span class="si">{</span><span class="n">err_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tile_description</span> <span class="o">=</span> <span class="n">tile_description</span>
|
|
|
|
<span class="n">operation</span> <span class="o">=</span> <span class="n">GemmOperationGrouped</span><span class="p">(</span>
|
|
<span class="n">arch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">current_cc</span><span class="p">,</span>
|
|
<span class="n">tile_description</span><span class="o">=</span><span class="n">tile_description</span><span class="p">,</span>
|
|
<span class="n">A</span><span class="o">=</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">tensor_B</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">tensor_C</span><span class="p">,</span>
|
|
<span class="n">epilogue_functor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">epilogue_functor</span><span class="p">,</span>
|
|
<span class="n">swizzling_functor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_swizzling_functor</span><span class="p">,</span>
|
|
<span class="n">precompute_mode</span><span class="o">=</span><span class="n">SchedulerMode</span><span class="o">.</span><span class="n">Device</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">operation</span></div>
|
|
|
|
<div class="viewcode-block" id="GroupedGemm.run"><a class="viewcode-back" href="../../../cutlass.op.html#cutlass.op.gemm_grouped.GroupedGemm.run">[docs]</a> <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">D</span><span class="p">,</span>
|
|
<span class="n">alpha</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sync</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
|
<span class="n">print_module</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">GemmGroupedArguments</span><span class="p">:</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Runs the kernel currently specified.</span>
|
|
|
|
<span class="sd"> By default, this call returns only once the kernel has completed. To launch the kernel</span>
|
|
<span class="sd"> and immediately return, set ``sync=False``. In this case, it is the responsibility of the</span>
|
|
<span class="sd"> caller to synchronize the results of the kernel before attempting to access outputs</span>
|
|
<span class="sd"> by calling ``sync()`` on the arguments returned from this call.</span>
|
|
|
|
<span class="sd"> :param A: list of tensors representing data type and layout of operand A</span>
|
|
<span class="sd"> :type A: list</span>
|
|
<span class="sd"> :param B: list of tensors representing data type and layout of operand B</span>
|
|
<span class="sd"> :type B: list</span>
|
|
<span class="sd"> :param C: list of tensors representing data type and layout of operand C</span>
|
|
<span class="sd"> :type C: list</span>
|
|
<span class="sd"> :param D: list of tensors representing data type and layout of operand D</span>
|
|
<span class="sd"> :type D: list</span>
|
|
<span class="sd"> :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B</span>
|
|
<span class="sd"> :param beta: scalar parameter beta from GEMM operation that scales operand C</span>
|
|
<span class="sd"> :param sync: whether the call should wait for the kernel to complete before returning</span>
|
|
<span class="sd"> :type sync: bool</span>
|
|
<span class="sd"> :param print_module: whether to print the emitted C++ code</span>
|
|
<span class="sd"> :type print_module: bool</span>
|
|
|
|
<span class="sd"> :return: arguments passed in to the kernel</span>
|
|
<span class="sd"> :rtype: cutlass.backend.GemmGroupedArguments</span>
|
|
<span class="sd"> """</span>
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">B</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">C</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">D</span><span class="p">):</span>
|
|
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">"Lengths of A, B, C, and D lists must be equal"</span><span class="p">)</span>
|
|
|
|
<span class="n">problem_sizes</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span> <span class="o">=</span> <span class="p">([</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">))</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">A</span><span class="p">)):</span>
|
|
<span class="n">As</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">A</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_a</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">,</span> <span class="s2">"A"</span><span class="p">)</span>
|
|
<span class="n">Bs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">B</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">,</span> <span class="s2">"B"</span><span class="p">)</span>
|
|
<span class="n">Cs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">C</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">C</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">,</span> <span class="s2">"C"</span><span class="p">)</span>
|
|
<span class="n">Ds</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_tensor</span><span class="p">(</span><span class="n">D</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_d</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_d</span><span class="p">,</span> <span class="s2">"D"</span><span class="p">)</span>
|
|
<span class="n">problem_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cutlass_bindings</span><span class="o">.</span><span class="n">gemm</span><span class="o">.</span><span class="n">GemmCoord</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
|
|
|
|
<span class="n">alpha</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_scalar</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="s2">"alpha"</span><span class="p">)</span>
|
|
<span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verify_scalar</span><span class="p">(</span><span class="n">beta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_element_c</span><span class="p">,</span> <span class="s2">"beta"</span><span class="p">)</span>
|
|
|
|
<span class="n">alignment_a</span> <span class="o">=</span> <span class="nb">min</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">A</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_a</span><span class="p">)</span> <span class="k">for</span> <span class="n">A</span> <span class="ow">in</span> <span class="n">As</span><span class="p">))</span>
|
|
<span class="n">alignment_b</span> <span class="o">=</span> <span class="nb">min</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">B</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_b</span><span class="p">)</span> <span class="k">for</span> <span class="n">B</span> <span class="ow">in</span> <span class="n">Bs</span><span class="p">))</span>
|
|
<span class="n">alignment_c</span> <span class="o">=</span> <span class="nb">min</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">possible_operations</span><span class="o">.</span><span class="n">find_alignment</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_layout_c</span><span class="p">)</span> <span class="k">for</span> <span class="n">C</span> <span class="ow">in</span> <span class="n">Cs</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tile_description</span><span class="p">,</span> <span class="n">alignment_A</span><span class="o">=</span><span class="n">alignment_a</span><span class="p">,</span> <span class="n">alignment_B</span><span class="o">=</span><span class="n">alignment_b</span><span class="p">,</span>
|
|
<span class="n">alignment_C</span><span class="o">=</span><span class="n">alignment_c</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
|
|
|
|
<span class="n">arguments</span> <span class="o">=</span> <span class="n">GemmGroupedArguments</span><span class="p">(</span>
|
|
<span class="n">operation</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="p">,</span>
|
|
<span class="n">problem_sizes</span><span class="o">=</span><span class="n">problem_sizes</span><span class="p">,</span>
|
|
<span class="n">A</span><span class="o">=</span><span class="n">As</span><span class="p">,</span> <span class="n">B</span><span class="o">=</span><span class="n">Bs</span><span class="p">,</span> <span class="n">C</span><span class="o">=</span><span class="n">Cs</span><span class="p">,</span> <span class="n">D</span><span class="o">=</span><span class="n">Ds</span><span class="p">,</span>
|
|
<span class="n">output_op</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="o">.</span><span class="n">epilogue_type</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">operation</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">arguments</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">sync</span><span class="p">:</span>
|
|
<span class="n">arguments</span><span class="o">.</span><span class="n">sync</span><span class="p">()</span>
|
|
|
|
<span class="k">return</span> <span class="n">arguments</span></div></div>
|
|
</pre></div>
|
|
</article>
|
|
</div>
|
|
<footer>
|
|
|
|
<div class="related-pages">
|
|
|
|
|
|
</div>
|
|
<div class="bottom-of-page">
|
|
<div class="left-details">
|
|
<div class="copyright">
|
|
Copyright © 2023, NVIDIA
|
|
</div>
|
|
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
|
|
|
|
<a href="https://github.com/pradyunsg/furo">Furo</a>
|
|
|
|
</div>
|
|
<div class="right-details">
|
|
<div class="icons">
|
|
<a class="muted-link" href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
|
|
<!-- Decorative icon: the link's aria-label supplies the accessible name, so hide the SVG from assistive tech and keep it out of the tab order. -->
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16" aria-hidden="true" focusable="false">
|
|
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.870.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.640-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
|
</svg>
|
|
</a>
|
|
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
</footer>
|
|
</div>
|
|
<!-- Table-of-contents drawer; marked "no-toc" and left empty on this page. -->
<aside class="toc-drawer no-toc">
|
|
|
|
|
|
|
|
</aside>
|
|
</div>
|
|
</div><script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
|
|
<!-- Page scripts emitted by the Sphinx/Furo build, kept in generated order (presumably copybutton.js relies on clipboard.min.js loading first; verify before reordering). require.js is loaded from cdnjs with an SRI integrity hash and crossorigin="anonymous". -->
<script src="../../../_static/doctools.js"></script>
|
|
<script src="../../../_static/sphinx_highlight.js"></script>
|
|
<script src="../../../_static/scripts/furo.js"></script>
|
|
<script src="../../../_static/clipboard.min.js"></script>
|
|
<script src="../../../_static/copybutton.js"></script>
|
|
<script src="../../../_static/tabs.js"></script>
|
|
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
|
</body>
|
|
</html> |