mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
CUTLASS 3.2.1 (#1113)
* Updates for 3.2.1 release. * Minor fix in gemm op profiler for raster order. * Add scheduler mapping for raster order in the kernels.
This commit is contained in:
36
python/pycute/__init__.py
Normal file
36
python/pycute/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from .int_tuple import *
|
||||
from .layout import *
|
||||
from .swizzle import *
|
||||
from .typing import *
|
||||
230
python/pycute/int_tuple.py
Normal file
230
python/pycute/int_tuple.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Functions for manipulating IntTuples
|
||||
"""
|
||||
|
||||
from functools import reduce
|
||||
from itertools import chain
|
||||
from typing import Union
|
||||
from .typing import Integer
|
||||
|
||||
|
||||
def is_int(x):
    """Return True when `x` is an instance of the package's `Integer` type."""
    integral = isinstance(x, Integer)
    return integral
|
||||
|
||||
|
||||
def is_tuple(x):
    """Return True when `x` is a `tuple` (or an instance of a tuple subclass)."""
    tuple_like = isinstance(x, tuple)
    return tuple_like
|
||||
|
||||
|
||||
def flatten(t):
    """Return the terminals of a (possibly nested) tuple as one flat tuple.

    A non-tuple argument is wrapped into a 1-tuple; an empty tuple flattens
    to an empty tuple.
    """
    if not is_tuple(t):
        return (t,)

    leaves = []
    for elem in t:
        # Each element contributes its own flattened terminals, in order.
        leaves.extend(flatten(elem))
    return tuple(leaves)
|
||||
|
||||
|
||||
def signum(a):
    """Return the sign of `a` as an int: 1 if positive, -1 if negative, else 0."""
    is_positive = bool(a > 0)
    is_negative = bool(a < 0)
    # bool arithmetic yields an int in {-1, 0, 1}.
    return is_positive - is_negative
|
||||
|
||||
|
||||
def product(a):
    """Return the product of all terminals of `a`.

    A non-tuple is returned unchanged; an empty tuple yields 1.
    """
    if not is_tuple(a):
        return a

    total = 1
    for elem in a:
        # Nested tuples contribute the product of their own terminals.
        total = total * product(elem)
    return total
|
||||
|
||||
|
||||
def inner_product(a, b):
    """Return the sum of element-wise products of two congruent int-tuples.

    When `a` is a tuple, `b` must have the same length and the function
    recurses pairwise. When `a` is a scalar, `b` must not be a tuple.
    """
    if not is_tuple(a):                       # "int" "int"
        assert not is_tuple(b)
        return a * b

    assert len(a) == len(b)                   # tuple tuple
    total = 0
    for x, y in zip(a, b):
        total = total + inner_product(x, y)
    return total
|
||||
|
||||
|
||||
def tuple_max(a):
    """Return the largest terminal found anywhere in the int-tuple `a`.

    A non-tuple is returned unchanged. An empty tuple raises `ValueError`,
    because `max` is called on an empty sequence.
    """
    if not is_tuple(a):
        return a
    return max(map(tuple_max, a))
|
||||
|
||||
|
||||
def elem_scale(a, b):
    """Scale the terminals of `a` by the matching entries of `b`.

    - tuple, tuple: recurse pairwise (lengths must match); returns a tuple.
    - int, tuple:   `a` is scaled by the product of `b`.
    - int, int:     plain multiplication.
    - tuple, int:   rejected with an assertion.
    """
    if not is_tuple(a):
        if is_tuple(b):                       # "int" tuple
            return elem_scale(a, product(b))
        return a * b                          # "int" "int"

    if not is_tuple(b):                       # tuple "int"
        assert False                          # Error
    else:                                     # tuple tuple
        assert len(a) == len(b)
        scaled = [elem_scale(x, y) for x, y in zip(a, b)]
        return tuple(scaled)
|
||||
|
||||
|
||||
# Inclusive prefix ceil div with output congruent to input a
|
||||
def shape_div(a, b):
    """Divide shape `a` by `b`, returning a result congruent to `a`.

    - tuple, tuple: pairwise division (lengths must match).
    - tuple, int:   each mode is divided by the current divisor, and the
                    divisor is then itself divided by that mode's size
                    before moving to the next mode.
    - int, tuple:   `a` is divided by the product of `b`.
    - int, int:     one operand must divide the other. Returns `a // b`
                    when `b` divides `a`, otherwise the sign of `a * b`.
    """
    if is_tuple(a):
        if is_tuple(b):                       # tuple tuple
            assert len(a) == len(b)
            return tuple(map(shape_div, a, b))

        # tuple "int": carry the divisor left to right through the modes
        divisor = b
        quotients = []
        for mode in a:
            quotients.append(shape_div(mode, divisor))
            divisor = shape_div(divisor, product(mode))
        return tuple(quotients)

    if is_tuple(b):                           # "int" tuple
        return shape_div(a, product(b))

    # "int" "int"
    assert a % b == 0 or b % a == 0
    return a // b if a % b == 0 else signum(a * b)
|
||||
|
||||
|
||||
# Exclusive prefix product with output congruent to input a
|
||||
def prefix_product(a, init=1):
    """Return the exclusive prefix product of `a`, congruent to `a`.

    - tuple, tuple: recurse pairwise with a per-mode initial value.
    - tuple, int:   each mode receives the running product accumulated so
                    far (starting at `init`); the running product is then
                    multiplied by that mode's size.
    - int, int:     the result is simply `init`.
    - int, tuple:   rejected with an assertion.
    """
    if not is_tuple(a):
        if is_tuple(init):                    # "int" tuple
            assert False                      # Error
        else:                                 # "int" "int"
            return init
    elif is_tuple(init):                      # tuple tuple
        assert len(a) == len(init)
        return tuple(map(prefix_product, a, init))
    else:                                     # tuple "int"
        running = init
        prefixes = []
        for mode in a:
            prefixes.append(prefix_product(mode, running))
            running = running * product(mode)
        return tuple(prefixes)
|
||||
|
||||
|
||||
def idx2crd(idx, shape, stride=None):
    """Convert an index into a coordinate congruent to `shape`.

    When `stride` is omitted it defaults to `prefix_product(shape)`.

    - tuple idx, tuple shape: recurse mode by mode (lengths must match).
    - int idx, tuple shape:   the same index is applied to every mode.
    - int idx, int shape:     `(idx // stride) % shape`.
    - tuple idx, int shape:   rejected with an assertion.
    """
    if stride is None:
        stride = prefix_product(shape)

    if is_tuple(shape):
        if is_tuple(idx):                     # tuple tuple tuple
            assert len(idx) == len(shape) and len(idx) == len(stride)
            coords = [idx2crd(i, s, d) for i, s, d in zip(idx, shape, stride)]
        else:                                 # "int" tuple tuple
            assert len(shape) == len(stride)
            coords = [idx2crd(idx, s, d) for s, d in zip(shape, stride)]
        return tuple(coords)
    elif is_tuple(idx):                       # tuple "int" "int"
        assert False                          # Error
    else:                                     # "int" "int" "int"
        quotient = idx // stride
        return quotient % shape
|
||||
|
||||
|
||||
def crd2idx(crd, shape, stride=None):
    """Convert a coordinate into a linear index.

    When `stride` is omitted it defaults to `prefix_product(shape)`.

    - tuple crd, tuple shape: sum of the per-mode indices (lengths must match).
    - int crd, tuple shape:   the integer is peeled apart mode by mode
                              (modulo / floor-divide by each mode's size);
                              the last mode receives whatever remains.
    - int crd, int shape:     `crd * stride`.
    - tuple crd, int shape:   rejected with an assertion.

    A scalar `crd` of `None` is treated as 0.
    """
    if stride is None:
        stride = prefix_product(shape)

    if is_tuple(crd):
        if is_tuple(shape):                   # tuple tuple tuple
            assert len(crd) == len(shape) and len(crd) == len(stride)
            total = 0
            for c, s, d in zip(crd, shape, stride):
                total = total + crd2idx(c, s, d)
            return total
        else:                                 # tuple "int" "int"
            assert False, f"crd={crd}, shape={shape}"      # Error
    else:
        remaining = 0 if crd is None else crd

        if not is_tuple(shape):               # "int" "int" "int"
            return remaining * stride

        # "int" tuple tuple
        assert len(shape) == len(stride)
        offset = 0
        for i, mode_shape in enumerate(shape[:-1]):
            extent = product(mode_shape)
            offset += crd2idx(remaining % extent, mode_shape, stride[i])
            remaining = remaining // extent
        # The final mode is not reduced modulo its extent.
        return offset + crd2idx(remaining, shape[-1], stride[-1])
|
||||
|
||||
|
||||
# Transform crd into the dst_shape's iteration space
|
||||
def crd2crd(crd, dst_shape, src_shape=None):
    """Transform `crd` into the iteration space of `dst_shape`.

    - tuple crd, tuple dst: recurse mode by mode (lengths must match). When
      `src_shape` is a tuple of the same length, its matching mode is passed
      down so nested tuple-crd / int-dst pairs can be resolved. Previously
      `src_shape` was dropped here and such nested pairs always asserted.
    - tuple crd, int dst:   ambiguous without `src_shape`; the coordinate is
      linearized with `crd2idx(crd, src_shape)`.
    - int crd, tuple dst:   expanded with `idx2crd(crd, dst_shape)`.
    - int crd, int dst:     returned as is, after asserting `crd < dst_shape`.
    """
    if is_tuple(crd):
        if is_tuple(dst_shape):                # tuple tuple
            assert len(crd) == len(dst_shape)
            # Forward the per-mode source shape only when it lines up with
            # crd; otherwise keep the historical behavior of passing None.
            if is_tuple(src_shape) and len(src_shape) == len(crd):
                sub_src = src_shape
            else:
                sub_src = (None,) * len(crd)
            return tuple(crd2crd(x, y, z) for x, y, z in zip(crd, dst_shape, sub_src))
        else:                                  # tuple "int"
            # Ambiguous unless we have src_shape
            assert src_shape is not None
            return crd2idx(crd, src_shape)
    else:
        if is_tuple(dst_shape):                # "int" tuple
            return idx2crd(crd, dst_shape)
        else:                                  # "int" "int"
            assert crd < dst_shape
            return crd
|
||||
|
||||
|
||||
# Filter trg according to crd: keep only elements of trg that are paired with None
|
||||
def slice_(crd: Union[None, tuple, int],
           trg: Union[tuple, int]):
    """Keep only the parts of `trg` that are paired with `None` in `crd`.

    - `crd` is None:  `trg` is kept, wrapped in a 1-tuple.
    - `crd` is tuple: `trg` must be a tuple of equal length; the kept pieces
                      of every mode are concatenated into one flat tuple.
    - otherwise:      nothing is kept and `()` is returned.
    """
    if crd is None:
        # match C++ behavior `return cute::tuple<B>{b};`
        return (trg,)

    if not is_tuple(crd):
        return ()

    if is_tuple(trg):                         # tuple tuple
        assert len(crd) == len(trg)
        # match C++ behavior of `filter_tuple` using `tuple_cat(...)`
        pieces = [slice_(c, s) for c, s in zip(crd, trg)]
        kept = []
        for piece in pieces:
            if piece != ():
                kept.extend(piece)
        return tuple(kept)
    else:
        assert False                          # tuple "int" : Error
|
||||
|
||||
|
||||
# Determine if None appears at any of an int_tuple's terminals
|
||||
def has_none(a: Union[None, tuple, int]):
    """Return True if `a` is None or any terminal of the tuple `a` is None."""
    if not is_tuple(a):
        return a is None

    for elem in a:
        if has_none(elem):
            return True
    return False
|
||||
358
python/pycute/layout.py
Normal file
358
python/pycute/layout.py
Normal file
@@ -0,0 +1,358 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Definition of CuTe Layouts and functions to manipulate them
|
||||
"""
|
||||
|
||||
from itertools import chain
|
||||
from typing import Union
|
||||
|
||||
from .int_tuple import *
|
||||
|
||||
|
||||
class LayoutBase:
    """Marker base class; `is_layout` recognizes any instance of it."""
|
||||
|
||||
|
||||
def is_layout(x):
    """Return True when `x` is an instance of `LayoutBase` (or a subclass)."""
    layout_like = isinstance(x, LayoutBase)
    return layout_like
|
||||
|
||||
|
||||
class Layout(LayoutBase):
    """A shape/stride pair that maps coordinates to linear indices.

    `shape` and `stride` are stored exactly as given. When no stride is
    supplied it defaults to `prefix_product(shape)`.
    """

    def __init__(self, _shape, _stride=None):
        self.shape = _shape
        if _stride is None:
            self.stride = prefix_product(self.shape)
        else:
            self.stride = _stride

    # operator ==
    def __eq__(self, other):
        # Let Python try the reflected comparison (and finally fall back to
        # "not equal") instead of raising AttributeError on foreign types.
        if not isinstance(other, Layout):
            return NotImplemented
        return self.shape == other.shape and self.stride == other.stride

    # operator len(L)  (len [rank] like tuples)
    def __len__(self):
        if is_tuple(self.shape):
            return len(self.shape)
        else:
            return 1

    # operator ()    (map coord to idx)
    def __call__(self, *args):
        """
        Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
        OR
        Slice the layout and return the sublayout (Coord has an Underscore slice op)

        Follow the same behavior of `Layout::operator(Coord const&)` in cute C++
        """
        # A single argument is the coordinate itself; several arguments are
        # treated together as one tuple coordinate.
        crd = args[0] if len(args) == 1 else args
        if has_none(args):
            return Layout(slice_(crd, self.shape), slice_(crd, self.stride))
        else:
            return crd2idx(crd, self.shape, self.stride)

    # operator []    (get-i like tuples)
    def __getitem__(self, i):
        if is_tuple(self.shape):
            return Layout(self.shape[i], self.stride[i])
        else:
            # A scalar layout has exactly one mode.
            assert i == 0
            return Layout(self.shape, self.stride)

    # size(layout)   Size of the domain
    def size(self):
        return product(self.shape)

    # cosize(layout)   Size of the codomain
    def cosize(self):
        # Index of the last coordinate plus one. The previous
        # max(shape * stride) formula was wrong for overlapping or strided
        # modes, e.g. (2,2):(1,1) reaches index 2 (cosize 3) but reported 2.
        return self(self.size() - 1) + 1

    # print and str
    def __str__(self):
        return f"{self.shape}:{self.stride}"

    # error msgs and representation
    def __repr__(self):
        return f"Layout({self.shape},{self.stride})"
|
||||
|
||||
|
||||
# Make Layout from a list of layouts (each layout is its own mode in the result)
|
||||
def make_layout(*layouts):
    """Combine layouts into one layout whose modes are the given layouts.

    Accepts either several layouts as positional arguments or a single
    iterable of layouts (list, generator, `chain`, ...).
    """
    if len(layouts) == 1 and not is_layout(layouts[0]):
        # A lone non-layout argument is taken to be an iterable of layouts.
        layouts = layouts[0]

    pairs = [(mode.shape, mode.stride) for mode in layouts]
    shape, stride = zip(*pairs)
    return Layout(shape, stride)
|
||||
|
||||
|
||||
# Size of the domain
|
||||
def size(layout):
    """Return the domain size of a layout, or the product of an int-tuple."""
    return layout.size() if is_layout(layout) else product(layout)
|
||||
|
||||
|
||||
# Size of the codomain
|
||||
def cosize(layout):
    """Return the codomain size by delegating to `layout.cosize()`."""
    codomain_size = layout.cosize()
    return codomain_size
|
||||
|
||||
|
||||
# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
|
||||
def coalesce(layout, profile=None):
    """Flatten `layout` and merge adjacent modes where possible.

    With a tuple `profile`, the first `len(profile)` modes are coalesced
    individually (recursively, using the matching profile entry) and the
    remaining modes are passed through unchanged.

    Otherwise the flattened modes are scanned left to right: size-1 modes
    are dropped, and a mode is merged into the previous one when the
    previous `shape * stride` equals its stride.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        split_at = len(profile)
        coalesced = (coalesce(layout[i], profile[i]) for i in range(0, split_at))
        untouched = (layout[i] for i in range(split_at, len(layout)))
        return make_layout(chain(coalesced, untouched))

    # Seed with a size-1 placeholder that the first real mode overwrites.
    shapes = [1]
    strides = [0]
    for extent, step in zip(flatten(layout.shape), flatten(layout.stride)):
        if extent == 1:
            # skip their shape-1s
            continue
        if shapes[-1] == 1:
            # replace our shape-1 with anything
            shapes[-1] = extent
            strides[-1] = step
        elif shapes[-1] * strides[-1] == step:
            # merge modes if the shape*stride match
            shapes[-1] = shapes[-1] * extent
        else:
            # append a new mode
            shapes.append(extent)
            strides.append(step)

    if len(shapes) == 1:
        return Layout(shapes[0], strides[0])
    return Layout(tuple(shapes), tuple(strides))
|
||||
|
||||
|
||||
# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
|
||||
def filter(layout, profile=None):
    """Drop stride-0 and size-1 modes from `layout`, then coalesce.

    With a tuple `profile`, the first `len(profile)` modes are filtered
    individually (recursively) and the remaining modes pass through.
    If every mode is dropped, the result is `Layout(1, 0)`.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        split_at = len(profile)
        filtered = (filter(layout[i], profile[i]) for i in range(0, split_at))
        untouched = (layout[i] for i in range(split_at, len(layout)))
        return make_layout(chain(filtered, untouched))

    kept_shapes = []
    kept_strides = []
    for extent, step in zip(flatten(layout.shape), flatten(layout.stride)):
        # skip their shape-1s and stride-0s
        if extent == 1 or step == 0:
            continue
        kept_shapes.append(extent)
        kept_strides.append(step)

    if len(kept_shapes) == 0:
        return Layout(1, 0)
    return coalesce(Layout(tuple(kept_shapes), tuple(kept_strides)))
|
||||
|
||||
|
||||
# Layout composition
|
||||
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
||||
def composition(layoutA, layoutB):
    """Compose `layoutA` with `layoutB`.

    Dispatch on `layoutB`:
    - None:   no-op, `layoutA` is returned.
    - int:    treated as `Layout(layoutB)`.
    - tuple:  applied by mode to the leading modes of `layoutA`; the
              remaining modes of `layoutA` pass through unchanged.
    - layout with tuple shape: composed with each mode of `layoutB` in turn.
    - layout with scalar shape/stride: the stride is threaded through the
      flattened modes of `layoutA` with `shape_div`, and the result is
      coalesced.
    """
    if layoutB is None:
        return layoutA
    if is_int(layoutB):
        return composition(layoutA, Layout(layoutB))
    if is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        split_at = len(layoutB)
        composed = (composition(layoutA[i], layoutB[i]) for i in range(0, split_at))
        untouched = (layoutA[i] for i in range(split_at, len(layoutA)))
        return make_layout(chain(composed, untouched))
    if is_tuple(layoutB.shape):
        return make_layout(composition(layoutA, layoutB[i]) for i in range(len(layoutB)))

    # From here on layoutB has a scalar shape and a scalar stride.
    if layoutB.stride == 0:
        return Layout(layoutB.shape, 0)

    flat_shape = flatten(layoutA.shape)
    flat_stride = flatten(layoutA.stride)

    out_shape = []
    out_stride = []
    remaining_shape = layoutB.shape
    remaining_stride = layoutB.stride
    # Every mode of A except the last consumes part of B's stride and shape.
    for s, d in zip(flat_shape[:-1], flat_stride[:-1]):
        s1 = shape_div(s, remaining_stride)
        out_shape.append(min(s1, remaining_shape))
        out_stride.append(remaining_stride * d)
        remaining_shape = shape_div(remaining_shape, abs(s1))
        remaining_stride = shape_div(remaining_stride, s)

    # The last mode of A takes whatever is left over.
    out_shape.append(remaining_shape)
    out_stride.append(remaining_stride * flat_stride[-1])

    return coalesce(Layout(tuple(out_shape), tuple(out_stride)))
|
||||
|
||||
|
||||
# Layout complement
|
||||
def complement(layout, max_idx=1):
    """Return the complement of `layout`, extended to cover `max_idx`.

    The flattened (stride, shape) pairs are visited in sorted order,
    skipping stride-0 and size-1 modes. For each visited mode a result mode
    of shape `stride // current_idx` and stride `current_idx` is emitted,
    and `current_idx` then advances to `shape * stride`. A final mode of
    shape `ceil(max_idx / current_idx)` is appended, and the result is
    coalesced.

    An integer `layout` is treated as `Layout(layout)`. `max_idx` is
    forwarded in that case; it used to be silently reset to 1.
    """
    if is_int(layout):
        return complement(Layout(layout), max_idx)

    result_shape = []
    result_stride = []
    current_idx = 1

    sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
    for (stride, shape) in sorted_DS:
        if stride == 0 or shape == 1:
            continue

        in_bound = current_idx <= shape * stride
        # To support symbolic value which can't be evaluated now
        assert (type(in_bound) is not bool) or in_bound

        result_shape.append(stride // current_idx)
        result_stride.append(current_idx)
        current_idx = shape * stride

    result_shape.append((max_idx + current_idx - 1) // current_idx)  # ceil_div
    result_stride.append(current_idx)

    return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
||||
|
||||
|
||||
# Layout right inverse
|
||||
def right_inverse(layout):
    """Build the right inverse of `layout`.

    `None` maps to `None`, and an integer `n` maps to `Layout(n)`.
    Otherwise the flattened modes are sorted by stride and consumed while
    each stride equals the running `shape * stride` of the modes taken so
    far (starting at 1). Size-1 modes are skipped, and the scan stops at the
    first mode that breaks the chain. Each consumed mode contributes its
    shape together with its exclusive prefix-product position in the
    original flattened shape. The result is coalesced.
    """
    if layout is None:
        return None
    if is_int(layout):
        return Layout(layout)

    flat_shape = flatten(layout.shape)
    flat_stride = flatten(layout.stride)
    by_stride = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape)))

    inv_shape = []
    inv_stride = []
    expected_stride = 1
    for stride, shape, rstride in by_stride:
        if shape == 1:
            continue
        if expected_stride != stride:
            break

        inv_shape.append(shape)
        inv_stride.append(rstride)
        expected_stride = shape * stride

    return coalesce(Layout(tuple(inv_shape), tuple(inv_stride)))
|
||||
|
||||
|
||||
# Layout left inverse
|
||||
def left_inverse(layout):
    """Build the left inverse of `layout`.

    `None` maps to `None`, and an integer `n` maps to `Layout(n)`.
    Otherwise the layout is paired with its complement and the right
    inverse of that pair is returned.
    """
    if layout is None:
        return None
    if is_int(layout):
        return Layout(layout)

    with_complement = make_layout(layout, complement(layout))
    return right_inverse(with_complement)
|
||||
|
||||
|
||||
# Split a layout by the composition of B and the "rest"
|
||||
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
||||
def logical_divide(layoutA, layoutB):
    """Split `layoutA` by composing it with `layoutB` and the "rest".

    Dispatch on `layoutB`:
    - None:  no-op, `layoutA` is returned.
    - int:   treated as `Layout(layoutB)`.
    - tuple: applied by mode to the leading modes of `layoutA`; the
             remaining modes pass through unchanged.
    - layout: `layoutA` is composed with `(layoutB, complement of layoutB
              up to size(layoutA))`.
    """
    if layoutB is None:
        return layoutA
    if is_int(layoutB):
        return logical_divide(layoutA, Layout(layoutB))
    if is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        split_at = len(layoutB)
        divided = (logical_divide(layoutA[i], layoutB[i]) for i in range(0, split_at))
        untouched = (layoutA[i] for i in range(split_at, len(layoutA)))
        return make_layout(chain(divided, untouched))

    rest = complement(layoutB, size(layoutA))
    tiler = make_layout(layoutB, rest)
    return composition(layoutA, tiler)
|
||||
|
||||
|
||||
# Reproduce a layoutA over a layoutB
|
||||
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
||||
def logical_product(layoutA, layoutB):
    """Reproduce `layoutA` over `layoutB`.

    Dispatch on `layoutB`:
    - None:  no-op, `layoutA` is returned.
    - int:   treated as `Layout(layoutB)`. This branch used to call
             `logical_divide` by mistake.
    - tuple: applied by mode to the leading modes of `layoutA`; the
             remaining modes pass through unchanged.
    - layout: returns `(layoutA, complement(layoutA, size(layoutA) *
              cosize(layoutB)) o layoutB)`.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        return logical_product(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range(0, len(layoutB))),
                                 (layoutA[i] for i in range(len(layoutB), len(layoutA)))))

    # The complement of A, sized to hold every replica, is what B indexes into.
    rest = complement(layoutA, size(layoutA) * cosize(layoutB))
    return make_layout(layoutA, composition(rest, layoutB))
|
||||
|
||||
|
||||
# Gather the modes from a hierarchical logical_divide or logical_product
|
||||
def hier_unzip(splitter, layoutA, layoutB):
    """Apply `splitter` hierarchically and gather the split modes.

    - `layoutB` is None:  returns `(Layout(1,0), layoutA)`.
    - `layoutB` is tuple: each leading mode of `layoutA` is split
      recursively, then the first halves are gathered into one mode and the
      second halves (followed by the untouched trailing modes of `layoutA`)
      into another.
    - otherwise: `splitter(layoutA, layoutB)` is returned; it must produce a
      rank-2 layout.
    """
    if layoutB is None:
        return make_layout(Layout(1, 0), layoutA)

    if is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        n_split = len(layoutB)
        # A layout with shape ((A,a),(B,b),(C,c))
        split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0, n_split))
        # Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
        firsts = make_layout(split[i][0] for i in range(0, n_split))
        seconds = make_layout(chain((split[i][1] for i in range(0, n_split)),
                                    (layoutA[i] for i in range(n_split, len(layoutA)))))
        return make_layout(firsts, seconds)

    # splitter must return a rank-2 layout
    return splitter(layoutA, layoutB)
|
||||
|
||||
|
||||
# Apply logical divide hierarchically and gather the split modes into two modes
|
||||
def zipped_divide(layoutA, layoutB):
    """Run `logical_divide` through `hier_unzip` on `layoutA` and `layoutB`."""
    zipped = hier_unzip(logical_divide, layoutA, layoutB)
    return zipped
|
||||
|
||||
|
||||
# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
|
||||
def tiled_divide(layoutA, layoutB):
    """Zipped-divide `layoutA` by `layoutB`, then unpack the second mode.

    The result keeps mode 0 of the zipped divide as its first mode, followed
    by each sub-mode of mode 1 as a separate top-level mode.
    """
    zipped = zipped_divide(layoutA, layoutB)
    modes = [zipped[0]]
    modes.extend(zipped[1][i] for i in range(len(zipped[1])))
    return make_layout(modes)
|
||||
|
||||
|
||||
# Apply logical product hierarchically and gather the split modes into two modes
|
||||
def zipped_product(layoutA, layoutB):
    """Run `logical_product` through `hier_unzip` on `layoutA` and `layoutB`."""
    zipped = hier_unzip(logical_product, layoutA, layoutB)
    return zipped
|
||||
|
||||
|
||||
# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
|
||||
def tiled_product(layoutA, layoutB):
    """Zipped-product `layoutA` with `layoutB`, then unpack the second mode.

    The result keeps mode 0 of the zipped product as its first mode,
    followed by each sub-mode of mode 1 as a separate top-level mode.
    """
    zipped = zipped_product(layoutA, layoutB)
    modes = [zipped[0]]
    modes.extend(zipped[1][i] for i in range(len(zipped[1])))
    return make_layout(modes)
|
||||
|
||||
|
||||
def slice_and_offset(crd: tuple,
                     layout: Layout):
    """Return `(sublayout, offset)` for slicing `layout` at `crd`.

    The sublayout keeps the modes of `layout` that are paired with `None`
    in `crd`. The offset is `crd2idx(crd, ...)` over the full layout, where
    a scalar `None` coordinate counts as 0.
    """
    kept_shape = slice_(crd, layout.shape)
    kept_stride = slice_(crd, layout.stride)
    sublayout = Layout(kept_shape, kept_stride)
    offset = crd2idx(crd, layout.shape, layout.stride)
    return (sublayout, offset)
|
||||
129
python/pycute/swizzle.py
Normal file
129
python/pycute/swizzle.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Methods for layout swizzling
|
||||
"""
|
||||
|
||||
from .layout import *
|
||||
|
||||
|
||||
def shiftr(a, s):
    """Shift `a` right by `s` bits; a negative `s` shifts left by `-s`.

    A shift of 0 returns `a` unchanged. It previously bounced between
    `shiftr` and `shiftl` until the recursion limit was hit.
    """
    return a >> s if s >= 0 else a << -s
|
||||
|
||||
|
||||
def shiftl(a, s):
    """Shift `a` left by `s` bits; a negative `s` shifts right by `-s`.

    A shift of 0 returns `a` unchanged. It previously bounced between
    `shiftl` and `shiftr` until the recursion limit was hit.
    """
    return a << s if s >= 0 else a >> -s
|
||||
|
||||
|
||||
## A generic Swizzle functor
|
||||
# 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
|
||||
# ^--^ Base is the number of least-sig bits to keep constant
|
||||
# ^-^ ^-^ Bits is the number of bits in the mask
|
||||
# ^---------^ Shift is the distance to shift the YYY mask
|
||||
# (pos shifts YYY to the right, neg shifts YYY to the left)
|
||||
#
|
||||
# e.g. Given
|
||||
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
|
||||
# the result is
|
||||
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
|
||||
#
|
||||
class Swizzle:
    """Functor that XORs one bit field of an integer into another.

    `bits` is the width of the mask, `base` the number of least-significant
    bits left untouched, and `shift` the distance between the two fields
    (positive moves the YYY field right onto ZZZ, negative moves it left).
    """

    def __init__(self, bits, base, shift):
        assert bits >= 0
        assert base >= 0
        assert abs(shift) >= bits
        self.bits = bits
        self.base = base
        self.shift = shift
        bit_msk = (1 << bits) - 1
        # YYY is the source field, ZZZ the field it is XOR-ed into.
        self.yyy_msk = bit_msk << (base + max(0, shift))
        self.zzz_msk = bit_msk << (base - min(0, shift))

    # operator ()    (transform integer)
    def __call__(self, offset):
        return offset ^ shiftr(offset & self.yyy_msk, self.shift)

    # Size of the domain
    def size(self):
        # Read the parameters from the instance; the bare names `bits`,
        # `base` and `shift` are not defined in this scope (NameError).
        return 1 << (self.bits + self.base + abs(self.shift))

    # Size of the codomain
    def cosize(self):
        return self.size()

    # print and str
    def __str__(self):
        return f"SW_{self.bits}_{self.base}_{self.shift}"

    # error msgs and representation
    def __repr__(self):
        return f"Swizzle({self.bits},{self.base},{self.shift})"
|
||||
|
||||
|
||||
class ComposedLayout(LayoutBase):
    """The composition `layoutB o offset o layoutA`.

    Calling it evaluates `layoutA` on the coordinate, adds `offset`, and
    passes the sum to `layoutB`.
    """

    def __init__(self, layoutB, offset, layoutA):
        self.layoutB = layoutB
        self.offset = offset
        self.layoutA = layoutA

    # operator ==
    def __eq__(self, other):
        # Let Python try the reflected comparison (and finally fall back to
        # "not equal") instead of raising AttributeError on foreign types.
        if not isinstance(other, ComposedLayout):
            return NotImplemented
        return self.layoutB == other.layoutB and self.offset == other.offset and self.layoutA == other.layoutA

    # operator len(L)  (len [rank] like tuples)
    def __len__(self):
        return len(self.layoutA)

    # operator ()    (map coord to idx)
    def __call__(self, *args):
        return self.layoutB(self.offset + self.layoutA(*args))

    # operator []    (get-i like tuples)
    def __getitem__(self, i):
        # Only the inner layout is indexed; layoutB and offset are shared.
        return ComposedLayout(self.layoutB, self.offset, self.layoutA[i])

    # size(layout)   Size of the domain
    def size(self):
        # `size` here resolves to the module-level function, not this method.
        return size(self.layoutA)

    # cosize(layout)   Size of the codomain
    def cosize(self):
        # `cosize` here resolves to the module-level function, not this method.
        return cosize(self.layoutB)

    # print and str
    def __str__(self):
        return f"{self.layoutB} o {self.offset} o {self.layoutA}"

    # error msgs and representation
    def __repr__(self):
        return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})"
|
||||
42
python/pycute/typing.py
Normal file
42
python/pycute/typing.py
Normal file
@@ -0,0 +1,42 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class Integer(ABC):
    """Abstract type matching `int` and its subclasses, excluding `bool`.

    `float` is rejected explicitly as well. Used through
    `isinstance(x, Integer)` / `issubclass(c, Integer)`.
    """

    @classmethod
    def __subclasshook__(cls, c):
        # bool is a subclass of int, so it has to be ruled out by hand.
        rejected = c in [bool, float]
        return False if rejected else issubclass(c, int)
|
||||
Reference in New Issue
Block a user