mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v3.9 (#2185)
* v3.8 update x * fix blackwell gg * doc change * doc change * doc change --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com> Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
@@ -117,12 +117,7 @@ def shape_div(a, b):
|
||||
return shape_div(a, product(b))
|
||||
else: # "int" "int"
|
||||
assert a % b == 0 or b % a == 0
|
||||
#return -(-a // b) # Python exclusive impl: "//" is always floor div
|
||||
if a % b == 0:
|
||||
return a // b
|
||||
else:
|
||||
return signum(a*b)
|
||||
|
||||
return (a + b - 1) // b
|
||||
|
||||
# Exclusive prefix product with output congruent to input a
|
||||
def prefix_product(a, init=1):
|
||||
|
||||
@@ -204,19 +204,28 @@ def composition(layoutA, layoutB):
|
||||
else:
|
||||
result_shape = []
|
||||
result_stride = []
|
||||
rest_shape = layoutB.shape
|
||||
rest_stride = layoutB.stride
|
||||
for (s, d) in zip(flatten(layoutA.shape)[:-1], flatten(layoutA.stride)[:-1]):
|
||||
s1 = shape_div(s, rest_stride)
|
||||
result_shape.append(min(s1,rest_shape))
|
||||
result_stride.append(rest_stride * d)
|
||||
rest_shape = shape_div(rest_shape, abs(s1))
|
||||
rest_stride = shape_div(rest_stride, s)
|
||||
rest_shape = layoutB.shape
|
||||
rest_stride = layoutB.stride
|
||||
flat_A = coalesce(layoutA)
|
||||
for (curr_shape, curr_stride) in zip(flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]):
|
||||
assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0
|
||||
new_shape = min(max(1, curr_shape // rest_stride), rest_shape)
|
||||
|
||||
result_shape.append(rest_shape)
|
||||
result_stride.append(rest_stride * flatten(layoutA.stride)[-1])
|
||||
if new_shape != 1:
|
||||
result_shape.append(new_shape)
|
||||
result_stride.append(rest_stride * curr_stride)
|
||||
|
||||
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
||||
rest_shape = rest_shape // new_shape
|
||||
rest_stride = -(-rest_stride // curr_shape) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
|
||||
|
||||
if rest_shape != 1 or len(result_shape) == 0:
|
||||
result_shape.append(rest_shape)
|
||||
result_stride.append(rest_stride * flatten(flat_A.stride)[-1])
|
||||
|
||||
if len(result_shape) == 1:
|
||||
return Layout(result_shape[0], result_stride[0])
|
||||
else:
|
||||
return Layout(tuple(result_shape), tuple(result_stride))
|
||||
|
||||
|
||||
# Layout complement
|
||||
|
||||
Reference in New Issue
Block a user