update 3.8 v2 (#2112)

* update 3.8 v2

* update 3.8

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Yujia Zhai
2025-02-19 19:03:14 -08:00
committed by GitHub
parent e9627ce55b
commit b84e9802d8
166 changed files with 3986 additions and 4037 deletions

View File

@@ -86,5 +86,3 @@
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
#endif

View File

@@ -208,6 +208,7 @@ to_CUtensorMapDataType() {
if constexpr (is_same_v<T, uint8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (is_same_v<T, float_e4m3_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (is_same_v<T, float_e5m2_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (is_same_v<T, float_ue8m0_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (is_same_v<T, type_erased_dynamic_float8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else
if constexpr (is_same_v<T, uint16_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else
if constexpr (is_same_v<T, uint32_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else
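
The hunk above routes the new float_ue8m0_t scale-factor type to CU_TENSOR_MAP_DATA_TYPE_UINT8, the same descriptor tag used for the other 8-bit floating-point payloads. Below is a minimal standalone sketch (hypothetical names, not the CuTe code) of the same if-constexpr dispatch pattern, assuming plain structs standing in for the FP8 types:

#include <cstdint>
#include <type_traits>

enum class MapDataType { U8, U16, U32 };

struct e4m3  { std::uint8_t bits; };
struct e5m2  { std::uint8_t bits; };
struct ue8m0 { std::uint8_t bits; };  // stand-in for the scale-factor type added here

template <class T>
constexpr MapDataType to_map_data_type() {
  if constexpr (std::is_same_v<T, std::uint8_t>)  { return MapDataType::U8;  } else
  if constexpr (std::is_same_v<T, e4m3>)          { return MapDataType::U8;  } else
  if constexpr (std::is_same_v<T, e5m2>)          { return MapDataType::U8;  } else
  if constexpr (std::is_same_v<T, ue8m0>)         { return MapDataType::U8;  } else
  if constexpr (std::is_same_v<T, std::uint16_t>) { return MapDataType::U16; } else
                                                  { return MapDataType::U32; }
}

static_assert(to_map_data_type<ue8m0>() == MapDataType::U8,
              "every 8-bit payload shares the UINT8 descriptor tag");

int main() { return 0; }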

View File

@@ -956,7 +956,7 @@ template <class a_type, class b_type, class c_type, class sf_type,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_MXF4_SS
{
static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster OMMA.");
static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster MMA.");
static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4_SS N-mode size should be a multiple of 8 between 8 and 256.");
static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_SS Vector size can only be 16 or 32.");
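
The static_asserts above pin the legal shapes for this MMA atom: M fixed at 128, N a multiple of 8 in [8, 256], and a scale-factor vector size of 16 or 32. A small standalone sketch (hypothetical helper name) of the same compile-time checks:

template <int M, int N, int VS>
constexpr bool valid_mxf4_shape() {
  return (M == 128) &&
         (N % 8 == 0) && (8 <= N) && (N <= 256) &&
         (VS == 16 || VS == 32);
}

static_assert( valid_mxf4_shape<128, 256, 32>(), "largest supported tile");
static_assert(!valid_mxf4_shape<128, 260, 32>(), "N must be a multiple of 8 and <= 256");
static_assert(!valid_mxf4_shape< 64, 128, 16>(), "M must be 128 for 1 CTA cluster MMA");

int main() { return 0; }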

View File

@@ -45,7 +45,6 @@
namespace cute
{
template <>
struct Copy_Traits<SM100_U8x8_LDSM_T>
{

View File

@@ -372,7 +372,7 @@ void swap(array<T,N>& a, array<T,N>& b)
/// @return A cute::array of the elements of @c t in reverse order.
template <class T, size_t N>
CUTE_HOST_DEVICE constexpr
cute::array<T,N> reverse(cute::array<T,N> const& t)
cute::array<T,N> reverse(cute::array<T,N> const& t)
{
if constexpr (N == 0u) {
return t;
@@ -441,17 +441,6 @@ struct tuple_element<I, cute::array<T,N>>
using type = T;
};
template <class T, size_t N>
struct tuple_size<cute::array<T,N> const>
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
{};
template <size_t I, class T, size_t N>
struct tuple_element<I, cute::array<T,N> const>
{
using type = T;
};
} // end namespace CUTE_STL_NAMESPACE
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
@@ -477,16 +466,5 @@ struct tuple_element<I, cute::array<T,N>>
using type = T;
};
template <class T, size_t N>
struct tuple_size<cute::array<T,N> const>
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
{};
template <size_t I, class T, size_t N>
struct tuple_element<I, cute::array<T,N> const>
{
using type = T;
};
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD

View File

@@ -611,17 +611,6 @@ struct tuple_element<I, cute::array_subbyte<T,N>>
using type = T;
};
template <class T, size_t N>
struct tuple_size<const cute::array_subbyte<T,N>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
{};
template <size_t I, class T, size_t N>
struct tuple_element<I, const cute::array_subbyte<T,N>>
{
using type = T;
};
} // end namespace CUTE_STL_NAMESPACE
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
@@ -647,16 +636,5 @@ struct tuple_element<I, cute::array_subbyte<T,N>>
using type = T;
};
template <class T, size_t N>
struct tuple_size<const cute::array_subbyte<T,N>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
{};
template <size_t I, class T, size_t N>
struct tuple_element<I, const cute::array_subbyte<T,N>>
{
using type = T;
};
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD

View File

@@ -1,254 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/container/type_list.hpp>
namespace cute {
namespace detail {
// Empty Structure Optimization
template <bool IsFirstEmpty, bool IsRestEmpty, class... T>
struct ESO;
template <class First, class... Rest>
static constexpr bool is_first_empty_v = cute::is_empty<First>::value;
template <class First, class... Rest>
static constexpr bool is_rest_empty_v = (cute::is_empty<Rest>::value && ...);
template <class... T>
using ESO_t = ESO<is_first_empty_v<T...>, is_rest_empty_v<T...>, T...>;
// Empty First and Empty Rest...
template <class First, class... Rest>
struct ESO<true, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() {}
CUTE_HOST_DEVICE constexpr
ESO(First const&, Rest const&...) {}
};
// NonEmpty First and Empty Rest...
template <class First, class... Rest>
struct ESO<false, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() : first_{} {}
CUTE_HOST_DEVICE constexpr
ESO(First const& first, Rest const&...) : first_{first} {}
First first_;
};
// Empty First and NonEmpty Rest...
template <class First, class... Rest>
struct ESO<true, false, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() : rest_{} {}
CUTE_HOST_DEVICE constexpr
ESO(First const&, Rest const&... rest) : rest_{rest...} {}
ESO_t<Rest...> rest_;
};
// NonEmpty T and NonEmpty Rest...
template <class First, class... Rest>
struct ESO<false, false, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() : first_{}, rest_{} {}
CUTE_HOST_DEVICE constexpr
ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {}
First first_;
ESO_t<Rest...> rest_;
};
// Get Nth value from ESO
template <size_t N, class T, class... Rest, bool F, bool R>
CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...> const& s) {
if constexpr (N == 0) {
if constexpr (F) { return T{}; }
else { return static_cast<T const&>(s.first_); }
} else {
if constexpr (R) { return cute::tuple_element_t<N-1, cute::type_list<Rest...>>{}; }
else { return getv<N-1>(s.rest_); }
}
}
template <size_t N, class T, class... Rest, bool F, bool R>
CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...>& s) {
if constexpr (N == 0) {
if constexpr (F) { return T{}; }
else { return static_cast<T&>(s.first_); }
} else {
if constexpr (R) { return cute::tuple_element_t<N-1, cute::type_list<Rest...>>{}; }
else { return getv<N-1>(s.rest_); }
}
}
template <size_t N, class T, class... Rest, bool F, bool R>
CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...>&& s) {
if constexpr (N == 0) {
if constexpr (F) { return T{}; }
else { return static_cast<T&&>(s.first_); }
} else {
if constexpr (R) { return cute::tuple_element_t<N-1, cute::type_list<Rest...>>{}; }
else { return getv<N-1>(static_cast<ESO_t<Rest...>&&>(s.rest_)); }
}
}
// findt: Implementation detail of cute::find.
// If X is the first template argument of the tuple, findt returns C<N>.
template <class X, size_t N,
bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
CUTE_HOST_DEVICE constexpr
auto
findt(ESO<IsFirstEmpty, IsRestEmpty, First, Rest...> const& t) noexcept
{
if constexpr (cute::is_same_v<X, First>) {
return C<N>{};
}
else {
static_assert(sizeof...(Rest) != 0,
"The type does not appear in the argument list of the tuple.");
if constexpr (IsRestEmpty) {
// The rest is empty, so creating an instance of it is cheap.
return cute::detail::findt<X, N+1>(ESO_t<Rest...>{});
}
else {
return cute::detail::findt<X, N+1>(t.rest_);
}
}
}
} // end namespace detail
// packed_tuple<T...> is a tuple type that is a standard-layout type
// whenever all of its template arguments are standard layout types:
// (cute::is_standard_layout_v<T> && ...) implies (cute::is_standard_layout_v<packed_tuple<T...>>)
template <class... T>
struct packed_tuple : detail::ESO_t<T...>
{
CUTE_HOST_DEVICE constexpr
packed_tuple() {}
CUTE_HOST_DEVICE constexpr
packed_tuple(T const&... ts)
: detail::ESO_t<T...>(ts...)
{}
};
template <>
struct packed_tuple<> {};
template <size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(packed_tuple<T...> const& t) {
static_assert(I < sizeof...(T), "Index out of range");
return detail::getv<I>(t);
}
template <size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(packed_tuple<T...>& t) {
static_assert(I < sizeof...(T), "Index out of range");
return detail::getv<I>(t);
}
template <size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(packed_tuple<T...>&& t) {
static_assert(I < sizeof...(T), "Index out of range");
return detail::getv<I>(static_cast<detail::ESO_t<T...>&&>(t));
}
template <class... T>
CUTE_HOST_DEVICE constexpr
packed_tuple<T...>
make_packed_tuple(T const&... t)
{
return {t...};
}
// Returns the position of type X (as a static integer) in the tuple
// type's argument list. X must be unique in the argument list.
template <class X, class... T>
CUTE_HOST_DEVICE constexpr
auto
find(packed_tuple<T...> const& t) noexcept
{
return detail::findt<X, 0>(t);
}
} // end namespace cute
namespace CUTE_STL_NAMESPACE
{
template <class... T>
struct tuple_size<cute::packed_tuple<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, cute::packed_tuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // end namespace CUTE_STL_NAMESPACE
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
namespace std {
template <class ... T>
struct tuple_size<cute::packed_tuple<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class ... T>
struct tuple_element<I, cute::packed_tuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, cute::packed_tuple<T...>>
{};
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
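
The deleted packed_tuple.hpp above implements the "empty structure optimization" (ESO): elements whose types are empty are never stored, and the resulting tuple stays a standard-layout type because no empty-base inheritance is involved. The same ESO machinery reappears inline in cute/container/tuple.hpp later in this commit. Below is a hedged standalone illustration of the idea, not the CuTe implementation, using only the non-empty-head / empty-tail case:

#include <type_traits>

struct TagA {};                      // empty "static" element, like an integral_constant
struct TagB {};

template <class First, class... Rest>
struct eso_pair {                    // non-empty head, empty rest: store only the head
  First first_;
};

using Packed = eso_pair<int, TagA, TagB>;

static_assert(sizeof(Packed) == sizeof(int), "empty elements occupy no storage");
static_assert(std::is_standard_layout_v<Packed>, "no EBO base classes, so layout stays standard");

int main() { return 0; }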

View File

@@ -37,169 +37,183 @@
#include <cute/container/cuda_types.hpp>
#include <cute/container/type_list.hpp>
#if defined(CUTLASS_USE_PACKED_TUPLE)
# include <cute/container/packed_tuple.hpp>
#endif
//#include <cute/container/array.hpp> // Advanced optimizations
// cute::tuple is like std::tuple, with two differences.
// cute::tuple is like std::tuple, with differences:
//
// 1. It works on both host and device.
// 2. Its template arguments must be semiregular types.
// 3. It is always a standard-layout type if all of its template arguments are standard-layout types.
// 4. It is always an empty type if all of its template arguments are empty types.
//
// Semiregular types are default constructible and copyable.
// They include "value types" like int or float,
// but do _not_ include references like int& or float&.
// (See std::tie for an example of a tuple of references.)
//
// If the template arguments of cute::tuple are all empty types (in
// the sense of std::is_empty_v), then the cute::tuple is also an
// empty type. Furthermore, if CUTLASS_USE_PACKED_TUPLE is defined,
// cute::tuple is always a standard-layout type if all of its template
// arguments are standard-layout types.
namespace cute
{
#if defined(CUTLASS_USE_PACKED_TUPLE)
template<class... T>
using tuple = packed_tuple<T...>;
#else
namespace detail
{
// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of
// Standard-layout types preserve ABI across host-device boundaries.
// They are safe to use as device kernel parameters.
//
// The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of
// the conversion SFINAE, special overloading, and avoiding cvref template types.
//
// Compared to standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x.
// EBO stands for "empty base optimization."
namespace cute
{
namespace detail
{
// ESO stands for "empty structure optimization."
// We use this technique to ensure that cute::tuple
// doesn't need to waste space storing any template arguments
// of cute::tuple that have no data (like integral_constant).
// Otherwise, cute::tuple would need to spend at least 1 byte
// for each of its template arguments.
//
// This is one way in which cute::tuple differs from std::tuple.
// doesn't waste space storing template arguments that have no data (like integral_constant).
// Empty types in the template argument list are not even constructed,
// and do not have unique element addresses. In fact, they are not
// even members of the tuple or stored in any way. Calling `get`
// and do not have unique element addresses. Calling `get`
// constructs and returns an instance of an empty type on demand.
//
// EBO always "holds" a single value of type T.
// N is like an array index that TupleBase uses
// to access the desired tuple element.
template <size_t N, class T, bool IsEmpty = is_empty<T>::value>
struct EBO;
template <class T, size_t N, bool B>
CUTE_HOST_DEVICE constexpr C<N> findt(EBO<N, T, B> const&)
{ return {}; }
template <bool IsFirstEmpty, bool IsRestEmpty, class... T>
struct ESO;
// Specialization for types T that have no data;
// the "static tuple leaf." Valid T here include
// integral_constant<U, Value>, Int<Value>,
// and any other semiregular type
// for which std::is_empty_v<T> is true.
template <size_t N, class T>
struct EBO<N, T, true>
{
template <class First, class... Rest>
static constexpr bool is_first_empty_v = cute::is_empty<First>::value;
template <class First, class... Rest>
static constexpr bool is_rest_empty_v = (cute::is_empty<Rest>::value && ...);
template <class... T>
using ESO_t = ESO<is_first_empty_v<T...>, is_rest_empty_v<T...>, T...>;
// Empty First and Empty Rest...
template <class First, class... Rest>
struct ESO<true, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
EBO() {}
ESO() {}
CUTE_HOST_DEVICE constexpr
EBO(T const&) {}
ESO(First const&, Rest const&...) {}
};
template <size_t N, class T>
CUTE_HOST_DEVICE constexpr T getv(EBO<N, T, true> const&)
{ return {}; }
// This is a workaround for a shared memory misalignment issue (https://github.com/NVIDIA/cutlass/issues/1250).
// It will be removed once the corresponding compiler fix is released.
struct dummy_EBO_base {};
// Specialization for types T that are not empty;
// the "dynamic tuple leaf." Valid T here include int,
// any other integral or floating-point type,
// or any semiregular type for which std::is_empty_v<T> is false.
template <size_t N, class T>
struct EBO<N, T, false> : private dummy_EBO_base
{
// NonEmpty First and Empty Rest...
template <class First, class... Rest>
struct ESO<false, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
EBO() : t_{} {}
ESO() : first_{} {}
CUTE_HOST_DEVICE constexpr
EBO(T const& t) : t_{t} {}
ESO(First const& first, Rest const&...) : first_{first} {}
T t_;
First first_;
};
template <size_t N, class T>
CUTE_HOST_DEVICE constexpr T const& getv(EBO<N, T, false> const& x)
{ return x.t_; }
template <size_t N, class T>
CUTE_HOST_DEVICE constexpr T& getv(EBO<N, T, false>& x)
{ return x.t_; }
template <size_t N, class T>
CUTE_HOST_DEVICE constexpr T&& getv(EBO<N, T, false>&& x)
{ return cute::move(x.t_); }
template <class IdxSeq, class... T>
struct TupleBase;
// Base class of cute::tuple binds each element to an index
// by inheriting from EBO<i, t> for each (i, t) in (I..., T...).
// The storage (for nonempty t) lives in the base classes.
template <size_t... I, class... T>
struct TupleBase<index_sequence<I...>, T...>
: EBO<I,T>...
{
// Empty First and NonEmpty Rest...
template <class First, class... Rest>
struct ESO<true, false, First, Rest...> {
CUTE_HOST_DEVICE constexpr
TupleBase() {}
ESO() : rest_{} {}
CUTE_HOST_DEVICE constexpr
TupleBase(T const&... t) : EBO<I,T>(t)... {}
ESO(First const&, Rest const&... rest) : rest_{rest...} {}
ESO_t<Rest...> rest_;
};
// NonEmpty T and NonEmpty Rest...
template <class First, class... Rest>
struct ESO<false, false, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() : first_{}, rest_{} {}
CUTE_HOST_DEVICE constexpr
ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {}
First first_;
ESO_t<Rest...> rest_;
};
// Get Nth value from ESO
template <size_t N, bool F, bool R, class T, class... Rest>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
cute::tuple_element_t<N, cute::type_list<T, Rest...>>>
getv(ESO<F, R, T, Rest...> const&)
{
return {};
}
template <size_t N, bool F, bool R, class T, class... Rest>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
cute::tuple_element_t<N, cute::type_list<T, Rest...>> const&>
getv(ESO<F, R, T, Rest...> const& s)
{
if constexpr (N == 0) {
return static_cast<T const&>(s.first_);
} else {
return getv<N-1>(s.rest_);
}
}
template <size_t N, bool F, bool R, class T, class... Rest>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
cute::tuple_element_t<N, cute::type_list<T, Rest...>> &>
getv(ESO<F, R, T, Rest...>& s)
{
if constexpr (N == 0) {
return static_cast<T&>(s.first_);
} else {
return getv<N-1>(s.rest_);
}
}
template <size_t N, bool F, bool R, class T, class... Rest>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
cute::tuple_element_t<N, cute::type_list<T, Rest...>> &&>
getv(ESO<F, R, T, Rest...>&& s)
{
if constexpr (N == 0) {
return static_cast<T&&>(s.first_);
} else {
return getv<N-1>(static_cast<ESO_t<Rest...>&&>(s.rest_));
}
}
template <class X, size_t N,
bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
CUTE_HOST_DEVICE constexpr
auto
findt(ESO<IsFirstEmpty, IsRestEmpty, First, Rest...> const& t) noexcept
{
if constexpr (cute::is_same_v<X, First>) {
return C<N>{};
} else
if constexpr (sizeof...(Rest) == 0) {
return C<N+1>{};
} else
if constexpr (IsRestEmpty) {
return cute::detail::findt<X, N+1>(ESO_t<Rest...>{});
} else {
return cute::detail::findt<X, N+1>(t.rest_);
}
}
} // end namespace detail
// Attempting to use the following commented-out alias
// in the declaration of `struct tuple` causes MSVC 2022 build errors.
//
//template <class... T>
//using TupleBase = detail::TupleBase<make_index_sequence<sizeof...(T)>, T...>;
// This is the actual cute::tuple class.
// The storage (if any) lives in TupleBase's EBO base classes.
//
// Inheriting from the above alias TupleBase
// causes MSVC 2022 build errors when assigning one tuple to another:
// In summary: this is verbose as a work-around for MSVC build errors.
template <class... T>
struct tuple : detail::TupleBase<make_index_sequence<sizeof...(T)>, T...>
struct tuple : detail::ESO_t<T...>
{
CUTE_HOST_DEVICE constexpr
tuple() {}
CUTE_HOST_DEVICE constexpr
tuple(T const&... t) : detail::TupleBase<make_index_sequence<sizeof...(T)>, T...>(t...) {}
tuple(T const&... t) : detail::ESO_t<T...>(t...) {}
};
template <>
struct tuple<>
{};
//
// get for cute::tuple (just like std::get for std::tuple)
//
struct tuple<> {};
// Returns the element in the ith position of the tuple
template <size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
@@ -224,25 +238,19 @@ decltype(auto)
get(tuple<T...>&& t) noexcept
{
static_assert(I < sizeof...(T), "Index out of range");
return detail::getv<I>(static_cast<tuple<T...>&&>(t));
return detail::getv<I>(static_cast<detail::ESO_t<T...>&&>(t));
}
//
// find a type X within a cute::tuple
// Requires X to be unique in tuple
// Returns a static integer
//
// Returns the position of type X (as a static integer) in the tuple
// type's argument list. X must be unique in the argument list.
template <class X, class... T>
CUTE_HOST_DEVICE constexpr
auto
find(tuple<T...> const& t) noexcept
{
return detail::findt<X>(t);
return detail::findt<X, 0>(t);
}
#endif // CUTLASS_USE_PACKED_TUPLE
//
// Custom is_tuple trait simply checks the existence of tuple_size
// and assumes std::get<I>(.), std::tuple_element<I,.>
@@ -258,7 +266,7 @@ auto has_tuple_size(...) -> false_type;
template <class T>
struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {};
template<typename T>
template <class T>
constexpr bool is_tuple_v = cute::is_tuple<T>::value;
//
@@ -679,8 +687,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t)
} // end namespace cute
#if ! defined(CUTLASS_USE_PACKED_TUPLE)
namespace CUTE_STL_NAMESPACE
{
@@ -694,22 +700,8 @@ struct tuple_element<I, cute::tuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
template <class... T>
struct tuple_size<const cute::tuple<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, const cute::tuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // end namespace CUTE_STL_NAMESPACE
//
// std compatibility
//
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
namespace std
{
@@ -732,17 +724,5 @@ struct tuple_element<I, cute::tuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
template <class... T>
struct tuple_size<const cute::tuple<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, const cute::tuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
#endif // CUTLASS_USE_PACKED_TUPLE
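
The new getv overloads above dispatch on whether the requested element type is empty: stored (non-empty) elements come back by reference, while empty elements are constructed on demand and returned by value. A minimal sketch of that pattern, not the CuTe code, assuming a two-element "tuple" with one stored int and one empty tag:

#include <cassert>
#include <cstddef>

struct Empty {};                     // stands in for an integral_constant-like tag

struct Pair {                        // "tuple" of (int, Empty): only the int is stored
  int first_;
};

template <std::size_t N>
decltype(auto) get(Pair& p) {
  if constexpr (N == 0) { return static_cast<int&>(p.first_); }  // stored element: by reference
  else                  { return Empty{}; }                       // empty element: by value, on demand
}

int main() {
  Pair p{42};
  get<0>(p) = 7;                        // reference into storage
  [[maybe_unused]] Empty e = get<1>(p); // freshly constructed value
  assert(p.first_ == 7);
}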

View File

@@ -73,17 +73,6 @@ struct tuple_element<I, cute::type_list<T...>>
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
};
template <class... T>
struct tuple_size<const cute::type_list<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, const cute::type_list<T...>>
{
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
};
} // end namespace std
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
@@ -109,16 +98,5 @@ struct tuple_element<I, cute::type_list<T...>>
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
};
template <class... T>
struct tuple_size<const cute::type_list<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, const cute::type_list<T...>>
{
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
};
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD

View File

@@ -330,7 +330,7 @@ ceil_div(IntTupleA const& a, IntTupleB const& b)
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); });
} else { // tuple int
auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
[] (auto const& init, auto const& ai) {
return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai));
});
@@ -390,7 +390,7 @@ shape_div(IntTupleA const& a, IntTupleB const& b)
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); });
} else { // tuple int
auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
[] (auto const& init, auto const& ai) {
return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai));
});
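
In the tuple/int branches above, the fold walks the modes of the shape while distributing the scalar divisor across them. A hedged standalone rework of that fold over plain ints, with a worked example assuming a = (4, 6) and b = 8: mode 0 yields ceil_div(4, 8) = 1 with remaining divisor ceil_div(8, 4) = 2, and mode 1 yields ceil_div(6, 2) = 3, so the result is (1, 3).

#include <array>
#include <cstddef>

constexpr int ceil_div_int(int a, int b) { return (a + b - 1) / b; }

template <std::size_t N>
constexpr std::array<int, N> ceil_div_tuple(std::array<int, N> a, int b) {
  std::array<int, N> result{};
  int rest = b;                        // divisor still to be distributed
  for (std::size_t i = 0; i < N; ++i) {
    result[i] = ceil_div_int(a[i], rest);
    rest      = ceil_div_int(rest, a[i]);
  }
  return result;
}

constexpr auto r = ceil_div_tuple<2>({4, 6}, 8);
static_assert(r[0] == 1 && r[1] == 3, "matches the worked example above");

int main() { return 0; }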

View File

@@ -1044,7 +1044,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
auto result_shape_0 = take<0,R-1>(lhs_shape);
// Mod out the rhs_shape from the lhs_shape
auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape),
auto [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape),
[] (auto const& init, auto const& si) {
return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si)));
});
@@ -1058,7 +1058,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
auto result_stride_0 = take<0,R-1>(lhs_stride);
// Divide out the rhs_stride from the lhs_shape
auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride),
auto [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride),
[] (auto const& init, auto const& di) {
return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di));
});
@@ -1067,7 +1067,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1));
// Mod out the rhs_shape from the lhs_shape
auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape),
auto [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape),
[] (auto const& init, auto const& si) {
return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si)));
});

View File

@@ -508,16 +508,6 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
template <class... T>
struct tuple_size<const cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // end namespace CUTE_STL_NAMESPACE
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
@@ -542,15 +532,5 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
template <class... T>
struct tuple_size<const cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD

View File

@@ -84,7 +84,6 @@ using CUTE_STL_NAMESPACE::uint16_t;
using CUTE_STL_NAMESPACE::uint32_t;
using CUTE_STL_NAMESPACE::uint64_t;
using cutlass::uint128_t;
template <int N> struct uint_bit;
template <> struct uint_bit< 1> { using type = uint1_t; };
template <> struct uint_bit< 2> { using type = uint2_t; };
@@ -95,7 +94,6 @@ template <> struct uint_bit< 16> { using type = uint16_t; };
template <> struct uint_bit< 32> { using type = uint32_t; };
template <> struct uint_bit< 64> { using type = uint64_t; };
template <> struct uint_bit<128> { using type = cutlass::uint128_t; };
template <int N>
using uint_bit_t = typename uint_bit<N>::type;

View File

@@ -235,7 +235,7 @@ struct Tensor
decltype(auto)
operator()(Coord const& coord) {
if constexpr (has_underscore<Coord>::value) {
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
auto [sliced_layout,offset] = slice_and_offset(coord, layout());
return make_tensor(data() + offset, sliced_layout);
} else {
return data()[layout()(coord)];
@@ -249,7 +249,7 @@ struct Tensor
decltype(auto)
operator()(Coord const& coord) const {
if constexpr (has_underscore<Coord>::value) {
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
auto [sliced_layout,offset] = slice_and_offset(coord, layout());
return make_tensor(data() + offset, sliced_layout);
} else {
return data()[layout()(coord)];

View File

@@ -102,6 +102,7 @@
#if (!defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM100_ALL))
#define CUTLASS_ARCH_MMA_SM100A_ENABLED 1
#endif
#endif
#endif

View File

@@ -38,10 +38,14 @@
#include "cutlass/cutlass.h"
#ifndef CUDA_CTA_RECONFIG_ACTIVATED
#if (__CUDACC_VER_MAJOR__ >= 12 && \
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL))
#if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \
(__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \
|| (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \
)
#define CUDA_CTA_RECONFIG_ACTIVATED 1
#endif
#endif
namespace cutlass {

View File

@@ -106,7 +106,6 @@ struct CollectiveConv<
using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;
// TODO: move pipeline mode tiling into the collective setup phase instead
static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");
static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");

View File

@@ -255,23 +255,27 @@ public:
CUTLASS_HOST_DEVICE
int64_t activation_size() const {
return (N * H * W * C);
return static_cast<int64_t>(N) * static_cast<int64_t>(H) *
static_cast<int64_t>(W) * static_cast<int64_t>(C);
}
/// Returns filter size in number of elements
CUTLASS_HOST_DEVICE
int64_t filter_size() const {
return (K * R * S * C / groups);
return static_cast<int64_t>(K) * static_cast<int64_t>(R) *
static_cast<int64_t>(S) * static_cast<int64_t>(C) /
static_cast<int64_t>(groups);
}
/// Returns output size in number of elements
CUTLASS_HOST_DEVICE
int64_t output_size() const {
return (N * P * Q * K);
return static_cast<int64_t>(N) * static_cast<int64_t>(P) *
static_cast<int64_t>(Q) * static_cast<int64_t>(K);
}
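
The casts matter because the extents are 32-bit ints, so the old expressions multiplied in int and could wrap before the conversion to int64_t. A small standalone check of the fixed pattern, using an example problem size that crosses the int limit:

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  int N = 2, H = 1024, W = 1024, C = 1024;                 // 2 * 2^30 = 2^31 elements
  int64_t elems = static_cast<int64_t>(N) * H * W * C;     // widen the first factor, the rest promote
  std::printf("activation elements: %lld\n", static_cast<long long>(elems));  // 2147483648
  std::printf("INT_MAX:             %d\n", std::numeric_limits<int>::max());  // 2147483647
  // The unwidened int product cannot represent 2^31, so the old code overflowed here.
}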
/// Returns padding as Tensor4DCoord
CUTLASS_HOST_DEVICE
cutlass::Tensor4DCoord padding() const {

View File

@@ -285,21 +285,27 @@ public:
CUTLASS_HOST_DEVICE
int64_t activation_size() const {
return (N * D * H * W * C);
return static_cast<int64_t>(N) * static_cast<int64_t>(D) *
static_cast<int64_t>(H) * static_cast<int64_t>(W) *
static_cast<int64_t>(C);
}
/// Returns filter size in number of elements
CUTLASS_HOST_DEVICE
int64_t filter_size() const {
return (K * T * R * S * C);
return static_cast<int64_t>(K) * static_cast<int64_t>(T) *
static_cast<int64_t>(R) * static_cast<int64_t>(S) *
static_cast<int64_t>(C);
}
/// Returns output size in number of elements
CUTLASS_HOST_DEVICE
int64_t output_size() const {
return (N * Z * P * Q * K);
return static_cast<int64_t>(N) * static_cast<int64_t>(Z) *
static_cast<int64_t>(P) * static_cast<int64_t>(Q) *
static_cast<int64_t>(K);
}
/// Returns padding as Coord3D

View File

@@ -114,6 +114,33 @@ public:
return status;
}
// Check that tensor sizes don't exceed maximum supported size
if (kConvolutionalOperator == conv::Operator::kFprop) {
if (args.problem_size.activation_size() * sizeof(ElementA) >=
(1ull << 31) ||
args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
args.problem_size.output_size() * sizeof(ElementC) >= (1ull << 31)) {
return Status::kErrorInvalidProblem;
}
}
else if (kConvolutionalOperator == conv::Operator::kDgrad ||
kConvolutionalOperator == conv::Operator::kDeconv) {
if (args.problem_size.activation_size() * sizeof(ElementC) >=
(1ull << 31) ||
args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
return Status::kErrorInvalidProblem;
}
}
else if (kConvolutionalOperator == conv::Operator::kWgrad) {
if (args.problem_size.activation_size() * sizeof(ElementB) >=
(1ull << 31) ||
args.problem_size.filter_size() * sizeof(ElementC) >= (1ull << 31) ||
args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
return Status::kErrorInvalidProblem;
}
}
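
The new checks above reject any problem whose per-operand byte count reaches 2^31 (2 GiB) with Status::kErrorInvalidProblem instead of launching. A hedged standalone sketch of the same arithmetic; the helper name and the half-precision fprop sizes below are illustrative, not CUTLASS APIs:

#include <cstdint>
#include <cstdio>

constexpr bool fits_in_31_bits(std::uint64_t bytes) { return bytes < (1ull << 31); }

int main() {
  // Fprop activations, 2-byte elements: 4 x 512 x 512 x 256 = 268,435,456 elements.
  std::uint64_t elems = 4ull * 512 * 512 * 256;
  std::uint64_t bytes = elems * 2;                          // 536,870,912 bytes: accepted
  std::printf("activation bytes = %llu, ok = %d\n",
              static_cast<unsigned long long>(bytes), fits_in_31_bits(bytes));

  // Quadrupling the channel count reaches exactly 2 GiB and the problem is rejected up front.
  std::uint64_t too_big = elems * 4 * 2;                    // 2,147,483,648 bytes
  std::printf("too_big bytes    = %llu, ok = %d\n",
              static_cast<unsigned long long>(too_big), fits_in_31_bits(too_big));
}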
// check group conv constraint
if (args.problem_size.groups != 1) {
if (kGroupMode == conv::GroupMode::kNone) {

View File

@@ -104,7 +104,7 @@ namespace cutlass {
#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
#if (__CUDACC_VER_MAJOR__ >= 13)
#if (__CUDACC_VER_MAJOR__ > 12)
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
template <typename... Args> \
@@ -142,7 +142,7 @@ namespace cutlass {
return reinterpret_cast<PFN_##func>(pfn)(args...); \
}
#endif // (__CUDACC_VERSION__ >= 12.5)
#endif // (__CUDACC_VER_MAJOR__ > 12)
#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)

View File

@@ -69,7 +69,7 @@ struct LayoutAwareConvertImpl {
auto&& src_vm = cute::recast<SrcArray>(src);
auto&& dst_vm = cute::recast<DstArray>(dst);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i <src.size(); ++i) {
for (int i = 0; i < src_vm.size(); ++i) {
dst_vm(i) = Converter::convert(src_vm(i));
}
}
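
The fix above changes the loop bound from the element count of the original tensor to the element count of the recast (vectorized) view, which has proportionally fewer entries. A hedged standalone illustration of why the old bound overran the view, assuming 8 scalars per vectorized chunk:

#include <array>
#include <cstdio>

int main() {
  std::array<char, 32> src{};                               // 32 scalar elements
  constexpr int kVecWidth = 8;
  int chunks = static_cast<int>(src.size()) / kVecWidth;    // "recast" view: 4 chunks
  std::printf("src.size() = %zu, recast size = %d\n", src.size(), chunks);
  // Iterating 'i < src.size()' over the chunked view would touch chunks 4..31,
  // which do not exist; 'i < chunks' is the correct bound.
}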

View File

@@ -60,7 +60,7 @@
#if ! defined(_MSC_VER)
#define CUTLASS_LAMBDA_FUNC_INLINE __attribute__((always_inline))
#else
#define CUTLASS_LAMBDA_FUNC_INLINE
#define CUTLASS_LAMBDA_FUNC_INLINE [[msvc::forceinline]]
#endif
#define CUTLASS_HOST __host__

View File

@@ -98,142 +98,6 @@ sm100_get_epilogue_smem_swizzle_layout_atom() {
}
}
// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
template <
class OpClass,
class CtaTileShape_MNK,
class EpilogueTileType,
class TmemWarpShape_MN,
class ElementC,
class StrideC,
class ElementD,
class StrideD,
class FusionOp
>
constexpr auto
sm100_compute_tile_shape_or_override() {
using namespace cute;
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto> &&
cute::is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
size<1>(CtaTileShape_MNK{}) == 256) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int DpFull = 32;
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
// Note:
// Set Epi_Tile_N to 128 to support OverlappingAccum for the largest tile.
// This is a generally workable epi_tile_N, though it does not promise the best perf.
return make_tile(Int<M>{}, Int<128>{});
}
else if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
constexpr bool DisableSource = is_void_v<ElementC>;
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
// Epilogues w/o residual load are less sensitive to smem allocation
// Target a fixed amount of compute per epilogue iteration
if (DisableSource) {
if (MaxBits == 4) {
// Make epilogue tile larger to reduce the epilogue iterations.
// 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
constexpr int ComputeElts = 8192;
return ComputeElts / M;
}
constexpr int ComputeElts = 4096;
return ComputeElts / M;
}
// Epilogues w/ residual load are more sensitive to smem allocation
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
else {
if (MaxBits == 32) {
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
}
// Per-column scaling is high register pressure, reduce tile to prevent spills
else if (FusionOp::IsPerColScaleSupported) {
return 32;
}
else if (MaxBits == 16) {
return (CtaN <= 128) ? 32 : 64;
}
else {
return 64;
}
}
}();
constexpr int N_min_C = (DisableSource || detail::is_m_major<StrideC>()) ? 8 * WarpN
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementC> * WarpN;
constexpr int N_min_D = (detail::is_m_major<StrideD>()) ? 8 * WarpN
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementD> * WarpN;
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
// stride by tmem warp layout and return a by-mode tiler
auto tile_m = Layout<Int<M>>{};
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
return make_tile(tile_m, coalesce(tile_n));
}
else if constexpr (cute::is_tuple<EpilogueTileType>::value) {
EpilogueTileType epi_tile;
constexpr int M = size<0>(shape(epi_tile));
constexpr int N = size<1>(shape(epi_tile));
static_assert(!is_layout<EpilogueTileType>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
static_assert(TmemWarpShape_MN{} == Shape<_2,_2>{} && (M == 32 || M == 64) ||
TmemWarpShape_MN{} == Shape<_4,_1>{} && (M == 64 || M == 128), "Unsupported tile shape");
static_assert(N % 8 == 0, "Unsupported tile shape");
return epi_tile;
}
else {
static_assert(cutlass::detail::dependent_false<EpilogueTileType>, "Invalid type for EpilogueTileType.");
}
}
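
A hedged worked example of the automatic tile selection above, for a hypothetical 128x256 CTA tile with fp16 C and D, a residual load enabled, and a (4,1) tmem warp shape; the constants mirror the branches in the code but the chosen configuration is an assumption for illustration only:

#include <algorithm>
#include <cstdio>

int main() {
  constexpr int CtaM = 128, CtaN = 256;
  constexpr int WarpM = 4,  WarpN = 1;        // TmemWarpShape_MN = (4,1)
  constexpr int DpFull = 32;                  // tmem datapaths per subpartition
  constexpr int MaxBits = 16;                 // fp16 C and D

  constexpr int M = std::min(CtaM, DpFull * WarpM);                 // 128: full 32dp load
  constexpr int N_perf  = (CtaN <= 128) ? 32 : 64;                  // 16-bit branch, large N -> 64
  constexpr int N_min_C = 128 / MaxBits * WarpN;                    // 8
  constexpr int N_min_D = 128 / MaxBits * WarpN;                    // 8
  constexpr int N = std::min(CtaN, std::max({N_perf, N_min_C, N_min_D}));  // 64

  std::printf("EpilogueTile = %d x %d\n", M, N);                    // 128 x 64
}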
template <class EpilogueScheduleType>
static constexpr bool IsPtrArrayDispatchPolicy =
cute::is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized1Sm> ||
cute::is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized2Sm>;
template <
class CtaTileShape_MNK,
class EpilogueTile_MN,
class ElementC,
class ElementD,
class Schedule
>
constexpr auto
sm100_get_tma_dispatch_policy() {
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{})));
constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup;
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
constexpr bool ReuseSmem = sizeof_bits_v<ElementC> > 8;
constexpr bool DelayTmaStore = false;
constexpr int StagesD = cute::min(EpiTiles, 2);
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
: cute::min(EpiTiles, 4);
if constexpr (detail::IsPtrArrayDispatchPolicy<Schedule>) {
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
}
else
{
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
}
}
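
Continuing the same hypothetical 128x256 CTA tile with a 128x64 epilogue tile, the stage counts chosen above work out as follows (NumThreadsPerWarpGroup is 128; the configuration remains an assumption for illustration):

#include <algorithm>
#include <cstdio>

int main() {
  constexpr int EpiTiles     = (128 / 128) * (256 / 64);            // 4 subtiles per CTA tile
  constexpr int FragmentSize = (128 * 64) / 128;                    // 64 elements per thread
  constexpr bool ReuseSmem   = 16 > 8;                              // fp16 C: reuse C smem for D
  constexpr int StagesD = std::min(EpiTiles, 2);                    // 2
  constexpr int StagesC = ReuseSmem ? std::max(std::min(EpiTiles, 4), StagesD + 1)
                                    : std::min(EpiTiles, 4);        // 4
  std::printf("EpiTiles=%d FragmentSize=%d StagesC=%d StagesD=%d\n",
              EpiTiles, FragmentSize, StagesC, StagesD);
}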
/*
* Returns the TMEM_LOAD copy op to be used for the epilogue
* Returned TMEM_LOAD op is such that the thread-value ownership matches the widest available
@@ -344,10 +208,10 @@ sm100_get_tmem_load_op() {
// For complex TF32 kernels
else if constexpr (sizeof_bits_v<ElementAccumulator> == 64 && sizeof_bits_v<ElementD> == 64) {
if constexpr (num_dp == 16) {
return TMEM::op_repeater<SM100_TMEM_LOAD_16dp256b1x, num_col_bits>();
return TMEM::op_repeater<SM100_TMEM_LOAD_16dp256b1x, num_col_bits/2>();
}
else {
return TMEM::op_repeater<SM100_TMEM_LOAD_32dp32b1x, num_col_bits>();
return TMEM::op_repeater<SM100_TMEM_LOAD_32dp32b1x, num_col_bits/2>();
}
}
// For narrow precision output
@@ -376,7 +240,6 @@ sm100_get_smem_store_op() {
static_assert(is_m_major || is_n_major, "Unsupported gmem layout");
// Check for TMEM_LOAD layouts that match the thread-value ownership pattern of stmatrix
// TODO: check copy vectorization instead!
constexpr bool use_stmatrix_m8n8_4x =
(sizeof_bits_v<ElementAccumulator> == 32 && sizeof_bits_v<ElementD> == 32 && is_n_major &&
( cute::is_same_v<AccLoadOp, SM100_TMEM_LOAD_16dp128b2x> ||
@@ -451,22 +314,7 @@ sm100_get_smem_store_op() {
}
}
template <class GmemStrideTypeD, class ElementD>
constexpr auto
sm100_get_register_transform_op() {
using namespace cute;
[[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{});
[[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{});
static_assert(is_m_major || is_n_major, "Unsupported gmem layout");
if constexpr (sizeof_bits_v<ElementD> == 4 && is_m_major) {
return SM50_Shuffle_U32_2x2Trans_XOR1{};
}
else {
return AutoVectorizingCopyWithAssumedAlignment<128>{};
}
}
// Selects the largest vectorized smem load atom available
// subject to constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership
@@ -503,30 +351,6 @@ sm100_get_smem_load_op() {
}
}
template <class Schedule, class LayoutTag>
constexpr auto
sm100_get_gmem_load_op() {
if constexpr (detail::is_im2col_mode<LayoutTag>) {
return SM90_TMA_LOAD_IM2COL{};
}
else {
return SM90_TMA_LOAD{};
}
}
template <class Schedule, class LayoutTag>
constexpr auto
sm100_get_gmem_store_op() {
if constexpr (detail::is_im2col_mode<LayoutTag>) {
return SM90_TMA_STORE_IM2COL{};
}
else {
return SM90_TMA_STORE{};
}
}
// aux fusion callbacks builder for sm100 tma epilogue
template <
int StagesC,
@@ -622,9 +446,9 @@ struct CallbacksBuilder<
// the fusion operation performed and the dispatch policy to use.
template <
class OpClass,
class CtaTileShape_MNK,
class MmaTileShape_MNK,
class ClusterShape_MNK,
class EpilogueTileType,
class TmemWarpShape_MN,
class ElementAccumulator,
class ElementCompute,
class ElementC_,
@@ -637,62 +461,237 @@ template <
class FusionOpOrCallbacks
>
struct Sm100TmaBuilderImpl {
private:
static constexpr bool Is1SmMma = is_base_of_v<TmaWarpSpecialized1Sm, Schedule>;
static constexpr bool Is2SmMma = is_base_of_v<TmaWarpSpecialized2Sm, Schedule>;
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
// Passing void C disables source load + smem allocation
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
using CopyOpS2G = decltype(detail::sm100_get_gmem_store_op<Schedule,GmemLayoutTagD>());
using CopyOpG2S = decltype(detail::sm100_get_gmem_load_op<Schedule,GmemLayoutTagC>());
using FusionOp = conditional_t<is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>,
FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
using EpilogueTile_MN = decltype(detail::sm100_compute_tile_shape_or_override<
OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN,
ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp>());
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{})));
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
static constexpr bool DisableSource = cute::is_void_v<ElementC_>;
using ElementC = cute::conditional_t<DisableSource,ElementD,ElementC_>; // prevents void ref breakages
using GmemLayoutTagC = cute::conditional_t<DisableSource,GmemLayoutTagD,GmemLayoutTagC_>;
using InternalSmemElementC = typename cutlass::detail::get_unpacked_element_type<ElementC>::type;
using InternalSmemElementD = typename cutlass::detail::get_unpacked_element_type<ElementD>::type;
using DispatchPolicy = decltype(detail::sm100_get_tma_dispatch_policy<
CtaTileShape_MNK, EpilogueTile_MN, ElementC_, ElementD, Schedule>());
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
// TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
// instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
using FusionCallbacks =
typename CallbacksBuilder<
DispatchPolicy,
FusionOpOrCallbacks,
CtaTileShape_MNK,
EpilogueTile_MN,
ElementAccumulator,
AccLoadOp
>::Callbacks;
static constexpr bool IsTaggedFusionOp = is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>;
using FusionOp = conditional_t<IsTaggedFusionOp, FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
static constexpr auto
cta_tile_shape() {
if constexpr (Is2SmMma) { // 2x1 threadblock shape
auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{};
auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode
return make_shape(cta_tile_m, mma_tile_n, mma_tile_k);
}
else { // 1x1 threadblock shape
return MmaTileShape_MNK{};
}
}
using CtaTileShape_MNK = decltype(cta_tile_shape());
static constexpr auto
tmem_warps() {
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
return Shape<_2,_2>{};
}
else {
return Shape<_4,_1>{};
}
}
using TmemWarpShape_MN = decltype(tmem_warps());
// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one.
static constexpr auto
epilogue_tile() {
using namespace cute;
if constexpr (is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
is_same_v<EpilogueTileType, EpilogueTileAuto> &&
size<1>(CtaTileShape_MNK{}) == 256) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int DpFull = 32;
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
// Note:
// Set Epi_Tile_N to 128 to support OverlappingAccum for the largest tile.
// This is a generally workable epi_tile_N, though it does not promise the best perf.
return make_tile(Int<M>{}, Int<128>{});
}
else if constexpr (is_same_v<EpilogueTileType, EpilogueTileAuto>) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
// Epilogues w/o residual load are less sensitive to smem allocation
// Target a fixed amount of compute per epilogue iteration
if (DisableSource) {
if (MaxBits == 4) {
// Make epilogue tile larger to reduce the epilogue iterations.
// 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
constexpr int ComputeElts = 8192;
return ComputeElts / M;
}
constexpr int ComputeElts = 4096;
return ComputeElts / M;
}
// Epilogues w/ residual load are more sensitive to smem allocation
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
else {
if (MaxBits == 32) {
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
}
// Per-column scaling is high register pressure, reduce tile to prevent spills
else if (FusionOp::IsPerColScaleSupported) {
return 32;
}
else if (MaxBits == 16) {
return (CtaN <= 128) ? 32 : 64;
}
else {
return 64;
}
}
}();
constexpr int N_min_C = (DisableSource || detail::is_m_major<GmemStrideTypeC>()) ? 8 * WarpN
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementC> * WarpN;
constexpr int N_min_D = (detail::is_m_major<GmemStrideTypeD>()) ? 8 * WarpN
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementD> * WarpN;
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
// stride by tmem warp layout and return a by-mode tiler
auto tile_m = Layout<Int<M>>{};
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
return make_tile(tile_m, coalesce(tile_n));
}
else {
static_assert(cute::is_tuple<EpilogueTileType>::value && not is_layout<EpilogueTileType>::value,
"EpilogueTile must be a cute::Tile or cute::Shape");
EpilogueTileType epi_tile;
constexpr int M = size<0>(shape(epi_tile));
constexpr int N = size<1>(shape(epi_tile));
static_assert(N % 8 == 0, "Unsupported tile shape");
return epi_tile;
}
}
using EpilogueTile_MN = decltype(epilogue_tile());
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{})));
static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup;
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
static constexpr auto
dispatch_policy() {
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
constexpr bool ReuseSmem = sizeof_bits_v<ElementC_> > 8;
// TMA store delay performs worse with residual loads
constexpr bool DelayTmaStore = is_void_v<ElementC_>;
constexpr int StagesD = cute::min(EpiTiles, 2);
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
: cute::min(EpiTiles, 4);
if constexpr (is_same_v<Schedule, PtrArrayTmaWarpSpecialized1Sm> ||
is_same_v<Schedule, PtrArrayTmaWarpSpecialized2Sm>) {
constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore_>{};
}
else {
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
}
}
static constexpr auto
fusion_callbacks() {
{
return typename CallbacksBuilder<
decltype(dispatch_policy()),
FusionOpOrCallbacks,
CtaTileShape_MNK,
EpilogueTile_MN,
ElementAccumulator,
AccLoadOp
>::Callbacks({},{});
}
}
static constexpr auto
gmem_load_op() {
if constexpr (detail::is_im2col_mode<GmemLayoutTagC>) {
return SM90_TMA_LOAD_IM2COL{};
}
else {
return SM90_TMA_LOAD{};
}
}
static constexpr auto
gmem_store_op() {
if constexpr (detail::is_im2col_mode<GmemLayoutTagD>) {
return SM90_TMA_STORE_IM2COL{};
}
else {
return SM90_TMA_STORE{};
}
}
static constexpr auto
register_shuffle_op() {
using namespace cute;
[[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{});
[[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{});
static_assert(is_m_major || is_n_major, "Unsupported gmem layout");
if constexpr (sizeof_bits_v<InternalSmemElementD> == 4 && is_m_major) {
return SM50_Shuffle_U32_2x2Trans_XOR1{};
}
else {
return AutoVectorizingCopyWithAssumedAlignment<128>{};
}
}
public:
using CollectiveOp =
cutlass::epilogue::collective::CollectiveEpilogue<
DispatchPolicy,
decltype(dispatch_policy()),
CtaTileShape_MNK,
EpilogueTile_MN,
ElementC_, // Need to pass void through to expose via GemmUniversal
GmemStrideTypeC,
ElementD,
GmemStrideTypeD,
FusionCallbacks,
decltype(fusion_callbacks()),
AccLoadOp,
CopyOpG2S,
decltype(gmem_load_op()),
decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, InternalSmemElementC, EpilogueTile_MN>()),
decltype(detail::sm100_get_smem_load_op<GmemStrideTypeC, InternalSmemElementC, ElementAccumulator, AccLoadOp>()),
CopyOpS2G,
decltype(gmem_store_op()),
decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, InternalSmemElementD, EpilogueTile_MN>()),
decltype(detail::sm100_get_smem_store_op<GmemStrideTypeD, InternalSmemElementD, ElementAccumulator, AccLoadOp>()),
decltype(detail::sm100_get_register_transform_op<GmemStrideTypeD, InternalSmemElementD>())
decltype(register_shuffle_op())
>;
};
@@ -702,7 +701,8 @@ struct Sm100TmaBuilderImpl {
// No smem builder
template <
class CtaTileShape_MNK,
class OpClass,
class MmaTileShape_MNK,
class ClusterShape_MNK,
class EpilogueTileType,
class ElementAccumulator,
@@ -718,8 +718,8 @@ template <
>
struct CollectiveBuilder<
arch::Sm100,
arch::OpClassTensorOp,
CtaTileShape_MNK,
OpClass,
MmaTileShape_MNK,
ClusterShape_MNK,
EpilogueTileType,
ElementAccumulator,
@@ -732,11 +732,16 @@ struct CollectiveBuilder<
AlignmentD,
EpilogueScheduleType,
FusionOpOrCallbacks,
cute::enable_if_t<cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized> ||
cute::is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized> >> {
cute::enable_if_t<is_base_of_v<NoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
is_base_of_v<NoSmemWarpSpecialized2Sm, EpilogueScheduleType> >
> {
private:
static_assert(cute::sizeof_bits_v<ElementD> != 6, "Output element requires TMA");
static_assert(cute::is_same_v<EpilogueTileType, EpilogueTileAuto>, "Epilogue subtiling requires smem");
static_assert(cute::sizeof_bits_v<ElementD> != 4 and cute::sizeof_bits_v<ElementD> != 6, "Output element requires smem");
static constexpr bool Is1SmMma = is_base_of_v<NoSmemWarpSpecialized1Sm, EpilogueScheduleType>;
static constexpr bool Is2SmMma = is_base_of_v<NoSmemWarpSpecialized2Sm, EpilogueScheduleType>;
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
static constexpr bool DisableSource = cute::is_void_v<ElementC_>;
using ElementC = cute::conditional_t<DisableSource, ElementD, ElementC_>; // prevents void ref breakages
@@ -744,173 +749,110 @@ struct CollectiveBuilder<
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
using FusionOp = conditional_t<is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>,
FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
static constexpr bool IsTaggedFusionOp = is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>;
using FusionOp = conditional_t<IsTaggedFusionOp, FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
// use a 4x2 division to select tmem load shape in order to maintain compatibility with both (4,1) and (2,2) layouts
using EpilogueTile = decltype(take<0,2>(CtaTileShape_MNK{}));
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_2>{}));
static constexpr auto
cta_tile_shape() {
if constexpr (Is2SmMma) { // 2x1 threadblock shape
auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{};
auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode
return make_shape(cta_tile_m, mma_tile_n, mma_tile_k);
}
else { // 1x1 threadblock shape
return MmaTileShape_MNK{};
}
}
using CtaTileShape_MNK = decltype(cta_tile_shape());
static constexpr auto
tmem_warps() {
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
return Shape<_2,_2>{};
}
else {
return Shape<_4,_1>{};
}
}
using TmemWarpShape_MN = decltype(tmem_warps());
static constexpr auto
epilogue_tile() {
using namespace cute;
if constexpr (not is_same_v<EpilogueTileType, EpilogueTileAuto>) {
static_assert(is_tuple_v<EpilogueTileType>, "Shape or Tile");
return EpilogueTileType{};
}
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp>) { // perf specialized case
constexpr int EpiM = size<0>(CtaTileShape_MNK{});
constexpr int EpiN = cute::min(_64{}, size<1>(CtaTileShape_MNK{}));
return Shape<Int<EpiM>, Int<EpiN>>{};
}
else {
return take<0,2>(CtaTileShape_MNK{});
}
}
using EpilogueTile = decltype(epilogue_tile());
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{}));
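// Editor's note: illustrative sketch of the auto epilogue-tile choice above, not part
// of this commit. Block-scaled kernels clamp the epilogue N-extent to 64; other kernels
// keep the full CTA tile. Plain ints stand in for the static cute shapes.
constexpr int auto_epi_n_sketch(int cta_tile_n, bool is_block_scaled) {
  return is_block_scaled ? (cta_tile_n < 64 ? cta_tile_n : 64) : cta_tile_n;
}
static_assert(auto_epi_n_sketch(256, true)  ==  64, "block-scaled: N clamped to 64");
static_assert(auto_epi_n_sketch(32,  true)  ==  32, "already below the clamp");
static_assert(auto_epi_n_sketch(256, false) == 256, "default: full CTA tile N");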
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup;
using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized;
static constexpr auto
dispatch_policy() {
if constexpr (is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized1Sm> ||
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized2Sm>) {
return Sm100PtrArrayNoSmemWarpSpecialized{};
}
else {
return Sm100NoSmemWarpSpecialized{};
}
}
using DispatchPolicy = decltype(dispatch_policy());
using AlignmentCType = Int<AlignmentC>;
using AlignmentDType = Int<AlignmentD>;
static constexpr auto
fusion_callbacks() {
constexpr thread::ScaleType::Kind ScaleType =
DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
if constexpr (IsDefaultFusionOp<FusionOp>::value && not is_same_v<OpClass,arch::OpClassBlockScaledTensorOp>) {
// Legacy codepath using thread::LinearCombination, do not expect this to be stable
return thread::LinearCombination<
ElementD, 1, ElementAccumulator, ElementCompute, ScaleType, FusionOp::RoundStyle, ElementC>({});
}
else {
return typename detail::CallbacksBuilder<
DispatchPolicy,
FusionOpOrCallbacks,
CtaTileShape_MNK,
EpilogueTile,
ElementAccumulator,
AccLoadOp
>::Callbacks({},{});
}
}
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
static constexpr thread::ScaleType::Kind ScaleType = DisableSource ?
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
using FusionCallbacks = cute::conditional_t<
IsDefaultFusionOp<FusionOp>::value,
// Legacy codepath using thread::LinearCombination, do not expect this to be stable
thread::LinearCombination<
ElementD, 1, ElementAccumulator, ElementCompute,
ScaleType, RoundStyle, ElementC>
,
typename detail::CallbacksBuilder<
public:
using CollectiveOp =
cutlass::epilogue::collective::CollectiveEpilogue<
DispatchPolicy,
FusionOpOrCallbacks,
CtaTileShape_MNK,
EpilogueTile,
ElementAccumulator,
AccLoadOp
>::Callbacks
>;
using CollectiveOp = cute::conditional_t<
cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized>,
cutlass::epilogue::collective::CollectiveEpilogue<
cutlass::epilogue::Sm100NoSmemWarpSpecialized,
EpilogueTile,
ElementC_,
GmemStrideTypeC,
ElementD,
GmemStrideTypeD,
FusionCallbacks,
decltype(fusion_callbacks()),
AccLoadOp,
AlignmentCType,
AlignmentDType
>,
cutlass::epilogue::collective::CollectiveEpilogue<
cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized,
EpilogueTile,
ElementC_,
GmemStrideTypeC,
ElementD,
GmemStrideTypeD,
FusionCallbacks,
AccLoadOp
>
>;
};
// No smem builder for OpClassBlockScaledTensorOp
template <
class CtaTileShape_MNK,
class ClusterShape_MNK,
class EpilogueTileType,
class ElementAccumulator,
class ElementCompute,
class ElementC_,
class GmemLayoutTagC_,
int AlignmentC,
class ElementD,
class GmemLayoutTagD,
int AlignmentD,
class EpilogueScheduleType,
class FusionOp
>
struct CollectiveBuilder<
arch::Sm100,
arch::OpClassBlockScaledTensorOp,
CtaTileShape_MNK,
ClusterShape_MNK,
EpilogueTileType,
ElementAccumulator,
ElementCompute,
ElementC_,
GmemLayoutTagC_,
AlignmentC,
ElementD,
GmemLayoutTagD,
AlignmentD,
EpilogueScheduleType,
FusionOp,
cute::enable_if_t<cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized> ||
cute::is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized> >> {
static_assert(cute::sizeof_bits_v<ElementD> != 6, "Output element requires smem");
static constexpr bool DisableSource = cute::is_void_v<ElementC_>;
using ElementC = cute::conditional_t<DisableSource, ElementD, ElementC_>; // prevents void ref breakages
using GmemLayoutTagC = cute::conditional_t<DisableSource, GmemLayoutTagD, GmemLayoutTagC_>;
static constexpr thread::ScaleType::Kind ScaleType = DisableSource ?
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
static_assert(cute::is_tuple<EpilogueTileType>::value || cute::is_same_v<EpilogueTileType, EpilogueTileAuto>);
using EpilogueTile = cute::conditional_t<cute::is_same_v<EpilogueTileType, EpilogueTileAuto>,
cute::Shape<_128, _64>,
EpilogueTileType
>;
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_1>{}));
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized;
using AlignmentCType = Int<AlignmentC>;
using AlignmentDType = Int<AlignmentD>;
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
static_assert(is_base_of_v<fusion::FusionOperation, FusionOp>, "only support EVT fusions");
using FusionCallbacks =
typename detail::CallbacksBuilder<
DispatchPolicy,
FusionOp,
CtaTileShape_MNK,
EpilogueTile,
ElementAccumulator,
AccLoadOp
>::Callbacks;
using CollectiveOp = cute::conditional_t<
cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized>,
cutlass::epilogue::collective::CollectiveEpilogue<
cutlass::epilogue::Sm100NoSmemWarpSpecialized,
EpilogueTile,
ElementC_,
GmemStrideTypeC,
ElementD,
GmemStrideTypeD,
FusionCallbacks,
AccLoadOp,
AlignmentCType,
AlignmentDType
>,
cutlass::epilogue::collective::CollectiveEpilogue<
cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized,
EpilogueTile,
ElementC_,
GmemStrideTypeC,
ElementD,
GmemStrideTypeD,
FusionCallbacks,
AccLoadOp
>
>;
Int<AlignmentC>,
Int<AlignmentD>
>;
};
// TMA epilogue builder
template <
class OpClass,
class CtaTileShape_MNK, // Static CTA tile shape
class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1)
class MmaTileShape_MNK,
class ClusterShape_MNK,
class EpilogueTileType,
class ElementAccumulator,
class ElementCompute,
@@ -926,7 +868,7 @@ template <
struct CollectiveBuilder<
arch::Sm100,
OpClass,
CtaTileShape_MNK,
MmaTileShape_MNK,
ClusterShape_MNK,
EpilogueTileType,
ElementAccumulator,
@@ -940,30 +882,20 @@ struct CollectiveBuilder<
EpilogueScheduleType,
FusionOp,
cute::enable_if_t<
// OpClass
( cute::is_same_v<OpClass, arch::OpClassTensorOp>
|| cute::is_same_v<OpClass, arch::OpClassBlockScaledTensorOp>
) &&
// Epilogue Schedule Type
( cute::is_base_of_v<TmaWarpSpecialized1Sm, EpilogueScheduleType> ||
cute::is_base_of_v<TmaWarpSpecialized2Sm, EpilogueScheduleType>
|| detail::IsPtrArrayDispatchPolicy<EpilogueScheduleType>
)>>
// Only support TensorOp kernels
not cute::is_same_v<OpClass, arch::OpClassSimt> &&
(cute::is_base_of_v<TmaWarpSpecialized1Sm, EpilogueScheduleType> ||
cute::is_base_of_v<TmaWarpSpecialized2Sm, EpilogueScheduleType>)
>
>
{
private:
using TmemWarpShape_MN = cute::conditional_t<size<0>(CtaTileShape_MNK{}) == 64 &&
(cute::is_base_of_v<TmaWarpSpecialized2Sm, EpilogueScheduleType>
|| cute::is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized2Sm>
),
Shape<_2,_2>, Shape<_4,_1>>;
public:
using CollectiveOp =
typename detail::Sm100TmaBuilderImpl<
OpClass,
CtaTileShape_MNK,
MmaTileShape_MNK,
ClusterShape_MNK,
EpilogueTileType,
TmemWarpShape_MN,
ElementAccumulator,
ElementCompute,
ElementC,
@@ -977,11 +909,11 @@ public:
>::CollectiveOp;
};
// Auto builder
// Auto epilogue builder for TensorOp kernels
template <
class OpClass,
class CtaTileShape_MNK, // Static CTA tile shape
class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1)
class MmaTileShape_MNK,
class ClusterShape_MNK,
class EpilogueTileType,
class ElementAccumulator,
class ElementCompute,
@@ -991,13 +923,12 @@ template <
class ElementD,
class GmemLayoutTagD,
int AlignmentD,
class EpilogueScheduleType,
class FusionOp
>
struct CollectiveBuilder<
arch::Sm100,
OpClass,
CtaTileShape_MNK,
MmaTileShape_MNK,
ClusterShape_MNK,
EpilogueTileType,
ElementAccumulator,
@@ -1008,30 +939,41 @@ struct CollectiveBuilder<
ElementD,
GmemLayoutTagD,
AlignmentD,
EpilogueScheduleType,
EpilogueScheduleAuto,
FusionOp,
cute::enable_if_t<
// OpClass
( cute::is_same_v<OpClass, arch::OpClassTensorOp>
|| cute::is_same_v<OpClass, arch::OpClassBlockScaledTensorOp>
)
// Epilogue Schedule Type
&& cute::is_same_v<EpilogueScheduleType, EpilogueScheduleAuto>>
// only for TensorOp kernels
cute::enable_if_t<not cute::is_same_v<OpClass, arch::OpClassSimt>>
>
{
private:
static_assert(cute::is_same_v<EpilogueTileType, EpilogueTileAuto>, "Don't specify epilogue tile with auto schedule");
using TmemWarpShape_MN = cute::conditional_t<size<0>(CtaTileShape_MNK{}) == 64 &&
size<0>(ClusterShape_MNK{}) % 2 == 0
,
Shape<_2,_2>, Shape<_4,_1>>;
static constexpr bool
is_2sm() {
using namespace cute;
constexpr int MmaTileM = size<0>(MmaTileShape_MNK{});
constexpr int ClusterM = size<0>(ClusterShape_MNK{});
constexpr bool StaticClusterM = is_static_v<decltype(get<0>(ClusterShape_MNK{}))>;
constexpr bool EvenClusterM = StaticClusterM && ClusterM % 2 == 0;
if constexpr (not EvenClusterM) {
return false;
}
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp>) {
return MmaTileM == 256;
}
else {
return MmaTileM == 256 || MmaTileM == 128;
}
}
using EpilogueSchedule = cute::conditional_t<is_2sm(), TmaWarpSpecialized2Sm, TmaWarpSpecialized1Sm>;
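// Editor's note: the same selection as is_2sm() above, restated as a plain-int truth
// table for illustration only; the real check is done on static cute shapes.
constexpr bool selects_2sm_sketch(int mma_tile_m, int cluster_m,
                                  bool cluster_m_is_static, bool is_block_scaled) {
  bool even_cluster_m = cluster_m_is_static && (cluster_m % 2 == 0);
  if (!even_cluster_m)  { return false; }                 // dynamic or odd cluster M -> 1SM
  if (is_block_scaled)  { return mma_tile_m == 256; }
  return mma_tile_m == 256 || mma_tile_m == 128;
}
static_assert( selects_2sm_sketch(256, 2, true,  false), "dense tensor op: MmaTileM 256 -> 2SM");
static_assert(!selects_2sm_sketch(128, 2, true,  true ), "block-scaled requires MmaTileM == 256");
static_assert(!selects_2sm_sketch(256, 4, false, false), "dynamic cluster M falls back to 1SM");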
public:
static_assert(cute::is_same_v<EpilogueTileType, EpilogueTileAuto>, "Don't specify epilogue tile with auto schedule");
using CollectiveOp =
typename detail::Sm100TmaBuilderImpl<
typename CollectiveBuilder<
arch::Sm100,
OpClass,
CtaTileShape_MNK,
MmaTileShape_MNK,
ClusterShape_MNK,
EpilogueTileType,
TmemWarpShape_MN,
ElementAccumulator,
ElementCompute,
ElementC,
@@ -1040,7 +982,7 @@ public:
ElementD,
GmemLayoutTagD,
AlignmentD,
EpilogueScheduleType,
EpilogueSchedule,
FusionOp
>::CollectiveOp;
};

View File

@@ -356,24 +356,21 @@ public:
}
// Represent the full output tensor, slice to get the tile this CTA is responsible for
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
// Partition source and destination tiles according to tmem copy T2R partitioning (tTR_)
auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r));
Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N)
Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N)
Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N)
Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N)
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
Tensor cD_epi = flat_divide( cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
// 2. Apply element-wise operation and store to gmem
// source is needed
@@ -410,7 +407,9 @@ template <
class ElementD_,
class StrideD_,
class ThreadEpilogueOp_,
class CopyOpT2R_
class CopyOpT2R_,
class AlignmentC,
class AlignmentD
>
class CollectiveEpilogue<
Sm100PtrArrayNoSmemWarpSpecialized,
@@ -420,7 +419,9 @@ class CollectiveEpilogue<
ElementD_,
StrideD_,
ThreadEpilogueOp_,
CopyOpT2R_
CopyOpT2R_,
AlignmentC,
AlignmentD
> : public detail::Sm100TmaWarpSpecializedAdapter<CollectiveEpilogue<
Sm100PtrArrayNoSmem,
EpilogueTile_,

View File

@@ -372,24 +372,21 @@ public:
auto cta_tiler = take<0,2>(cta_tile_shape_mnk);
// Represent the full output tensor, slice to get the tile this CTA is responsible for
Tensor mC = make_tensor(make_gmem_ptr<GmemElementC>(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor mC = make_tensor(make_gmem_ptr<GmemElementC>(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
// Partition source and destination tiles according to tmem copy T2R partitioning (tTR_)
auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r));
Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N)
Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N)
Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N)
Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N)
Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
Tensor cD_epi = flat_divide( cCD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
// 2. Apply element-wise operation and store to gmem
ThreadEpilogueOp epilogue_op{params.thread};
@@ -587,18 +584,18 @@ public:
int thread_idx = threadIdx.x % ThreadCount;
Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N)
Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N)
Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{}));
ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N)
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N)
constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount;
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
Tensor cD_epi = flat_divide(cD, EpilogueTile{});
Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l)
Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l)
Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tTR_cD(_,_,_,_0{},_0{})));
@@ -689,19 +686,22 @@ public:
do_acc_release = iter_m == size<3>(tTR_tAcc)-1 && iter_n == 0;
}
Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n);
Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n);
cst_callbacks.begin_loop(epi_m, epi_n);
if (is_C_load_needed) {
Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int<VC>{})));
Tensor tTR_gC_frg = recast<Array<GmemElementC, VC>>(coalesce(tTR_gC(_,_,_,epi_m,epi_n)));
Tensor tTR_rC_frg = recast<Array<GmemElementC, VC>>(coalesce(tCrC));
if constexpr (not cute::is_void_v<ElementC>) {
if (is_C_load_needed) {
using CVecType = uint_bit_t<VC * sizeof_bits_v<ElementC>>;
Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int<VC>{})));
auto pred_fn_C = [&] (auto const&... coords) {
return elem_less(tTR_cC_frag(coords...), problem_shape_mnl);
};
auto pred_fn_C = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
return elem_less(tTR_cC_frag(coords...), problem_shape_mnl);
};
copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg);
Tensor tTR_gC_frg = recast<CVecType>(coalesce(tTR_gC(_,_,_,epi_m,epi_n)));
Tensor tTR_rC_frg = recast<CVecType>(coalesce(tCrC));
copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg);
}
}
// Copy accumulator tile from tmem to register
@@ -733,17 +733,15 @@ public:
Tensor tTR_cD_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclD.compose(Int<VD>{})));
using VecType = uint_bit_t<VD * sizeof_bits_v<ElementD>>;
Tensor tTR_gD_frg = recast<VecType>(coalesce(tTR_gD(_,_,_,epi_m,epi_n)));
Tensor tTR_rD_frg = recast<VecType>(coalesce(tTR_rD));
auto pred_fn_D = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
return elem_less(tTR_cD_frag(coords...), problem_shape_mnl);
};
copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg);
using VecType = uint_bit_t<VD * sizeof_bits_v<ElementD>>;
Tensor tTR_gD_frg = recast<VecType>(coalesce(tTR_gD(_,_,_,epi_m,epi_n)));
Tensor tTR_rD_frg = recast<VecType>(coalesce(tTR_rD));
copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg);
} // for epi_m
} // for epi_n

View File

@@ -340,7 +340,7 @@ public:
_1{});
}
typename Params::TMA_D tma_store_d;
typename Params::TMA_D tma_store_d{};
if constexpr (is_destination_supported) {
ElementD const* ptr_D_first_batch = reinterpret_cast<ElementD const*>(args.ptr_D);
Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{})));

View File

@@ -287,7 +287,7 @@ public:
EpilogueTile{});
}
typename Params::TMA_D tma_store_d;
typename Params::TMA_D tma_store_d{};
if constexpr (is_destination_supported) {
Tensor tensor_d = make_tensor(make_gmem_ptr<TmaElementD>(args.ptr_D), make_layout(make_shape(M,N,L), args.dD));
tma_store_d = make_tma_copy_C_sm90(

View File

@@ -44,35 +44,30 @@ namespace cutlass::epilogue {
// Builder Epilogue Schedules
//
//////////////////////////////////////////////////////////////////////////////
// Pre-Hopper schedules
struct PtrArrayDefault {};
struct EpilogueSimtVectorized {};
struct EpiloguePtrArraySimtVectorized {};
// Hopper direct store schedules
struct NoSmemWarpSpecialized {};
struct PtrArrayNoSmemWarpSpecialized {};
struct PtrArrayNoSmemWarpSpecializedTransposed {};
// Hopper TMA schedules
struct TmaWarpSpecialized {};
struct TmaWarpSpecializedCooperative {};
struct PtrArrayTmaWarpSpecialized { static constexpr int NumEpilogueWarpGroups = 1; };
struct PtrArrayTmaWarpSpecializedPingpong { static constexpr int NumEpilogueWarpGroups = 2; };
struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; };
// Blackwell direct store schedules
struct NoSmemWarpSpecialized1Sm {};
struct NoSmemWarpSpecialized2Sm {};
struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
// Blackwell TMA schedules
struct TmaWarpSpecialized1Sm {};
struct TmaWarpSpecialized2Sm {};
struct PtrArrayTmaWarpSpecialized1Sm {};
struct PtrArrayTmaWarpSpecialized2Sm {};
struct PtrArrayTmaWarpSpecializedCooperative {
static constexpr int NumEpilogueWarpGroups = 2;
};
// Standard warp specialized epilogue
struct PtrArrayTmaWarpSpecialized {
static constexpr int NumEpilogueWarpGroups = 1;
};
// Pingpong kernel epilogue
struct PtrArrayTmaWarpSpecializedPingpong {
static constexpr int NumEpilogueWarpGroups = 2;
};
struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {};
struct PtrArrayTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {};
// DEPRECATED schedules, will be removed in next release
struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {};
struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {};

View File

@@ -53,6 +53,7 @@ struct FusionOperation {
// metadata types/queries that can be overridden
using ElementOutput = void;
using ElementCompute = void;
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate;
using ElementSource = void;
static constexpr bool IsSourceSupported = false;

View File

@@ -482,7 +482,6 @@ public:
/// Note: The below method applies only when problem_size_K <= 256 for signed int8 gemm
/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is
/// above.
/// TODO: Add logic to fallback to the default approach
template <
/// Data type used to load and store tensors
typename ElementOutput_,

View File

@@ -39,13 +39,8 @@
#include <type_traits>
#endif
#if !defined(__QNX__)
#include <cuda/std/version>
#if defined(_MSC_VER) && defined(CCCL_VERSION) && CCCL_VERSION >= 2008000
#include <cuda/std/__utility/swap.h>
#else
#include <cuda/std/utility>
#endif
#endif
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/uint128.h"

View File

@@ -51,18 +51,57 @@
#ifdef _MSC_VER
// Provides support for alternate operators such as 'and', 'or', ...
#include <ciso646>
#include <intrin.h>
#endif // _MSC_VER
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
# define CUTLASS_ARCH_CREDUX_ENABLED
#endif
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
CUTLASS_HOST_DEVICE int32_t popcount(int32_t x) {
#if defined(__CUDA_ARCH__)
return __popc(x);
#elif defined(__GNUC__) || defined(__clang__)
return __builtin_popcount(x);
#elif defined(_MSC_VER)
return __popcnt(x);
#else
int32_t count = 0;
while (x) {
count += x & 1;
x >>= 1;
}
return count;
#endif
}
CUTLASS_HOST_DEVICE int64_t popcount(int64_t x) {
#if defined(__CUDA_ARCH__)
return __popcll(x);
#elif defined(__GNUC__) || defined(__clang__)
return __builtin_popcountll(x);
#elif defined(_MSC_VER)
return __popcnt64(x);
#else
int64_t count = 0;
while (x) {
count += x & 1;
x >>= 1;
}
return count;
#endif
}
} // namespace detail
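// Editor's note: host-side sketch of the software fallback above, added for illustration
// only. It mirrors the #else branch but takes an unsigned argument so the right shift is
// well defined for all inputs; the real helpers operate on int32_t/int64_t.
constexpr int popcount_sketch(unsigned x) {
  int count = 0;
  while (x) { count += x & 1u; x >>= 1; }
  return count;
}
static_assert(popcount_sketch(0b10110u) == 3, "three set bits");
static_assert(popcount_sketch(0u) == 0, "no set bits");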
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct absolute_value_op {
CUTLASS_HOST_DEVICE
@@ -609,22 +648,7 @@ struct and_popc_add {
CUTLASS_HOST_DEVICE
C operator()(A const &a, B const &b, C const &c) const {
A and_result = a & b;
#if defined(__CUDA__ARCH__)
int popc_result = __popc(and_result);
if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __popc(static_cast<uint32_t>(and_result >> 32));
}
#else
int popc_result = __builtin_popcount(and_result);
if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __builtin_popcount(static_cast<uint32_t>(and_result >> 32));
}
#endif
int32_t popc_result = detail::popcount(and_result);
return C(popc_result) + c;
}
};
@@ -646,22 +670,7 @@ struct xor_popc_add {
CUTLASS_HOST_DEVICE
C operator()(A const &a, B const &b, C const &c) const {
A xor_result = a ^ b;
#if defined(__CUDA__ARCH__)
int popc_result = __popc(xor_result);
if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __popc(static_cast<uint32_t>(xor_result >> 32));
}
#else
int popc_result = __builtin_popcount(xor_result);
if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __builtin_popcount(static_cast<uint32_t>(xor_result >> 32));
}
#endif
int32_t popc_result = detail::popcount(xor_result);
return C(popc_result) + c;
}
};
@@ -682,22 +691,7 @@ struct or_popc_add {
CUTLASS_HOST_DEVICE
C operator()(A const &a, B const &b, C const &c) const {
A or_result = a | b;
#if defined(__CUDA__ARCH__)
int popc_result = __popc(or_result);
if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __popc(static_cast<uint32_t>(or_result >> 32));
}
#else
int popc_result = __builtin_popcount(or_result);
if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __builtin_popcount(static_cast<uint32_t>(or_result >> 32));
}
#endif
int32_t popc_result = detail::popcount(or_result);
return C(popc_result) + c;
}
};

View File

@@ -567,7 +567,7 @@ sm100_make_trivial_fastFP32_tiled_mma() {
}
/**
* @brief Check for U4_UNPACK_U8, U6_UNPACK_U8 alignment requirement
* @brief Check for F8F6F4 alignment requirement
*
* @tparam TileShape_MNK (MmaAtomShape_M, MmaAtomShape_N, TileShape_K)
* @tparam ClusterShape_MNK (cluster_M, cluster_N, cluster_K)

View File

@@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity in a GEMM with blockwise/groupwise scaling.
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int carveout_bytes_, int alignment = 128>
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int carveout_bytes_, int alignment = 128>
constexpr int
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
@@ -96,7 +96,7 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_>
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(scale_bits * ScaleMsPerTile) + // scale of tensor A
cutlass::bits_to_bytes(scale_bits * 1); // scale of tensor B
cutlass::bits_to_bytes(scale_bits * ScaleNsPerTile); // scale of tensor B
constexpr int stage_bytes = cutlass::round_up(stage_bytes_, alignment) +
static_cast<int>(mainloop_pipeline_bytes);
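// Editor's note: a worked instance of the per-stage smem budget above, for illustration
// only. Assumed values: 8-bit A/B operands, 32-bit block scales, a 128x128x128 tile,
// ScaleMsPerTile == ScaleNsPerTile == 1, 128-byte alignment; pipeline storage is ignored.
constexpr int a_bytes_sketch     = (8 * 128 * 128) / 8;            // 16384
constexpr int b_bytes_sketch     = (8 * 128 * 128) / 8;            // 16384
constexpr int scale_bytes_sketch = (32 * 1) / 8 + (32 * 1) / 8;    // 4 + 4
constexpr int stage_bytes_sketch =
    ((a_bytes_sketch + b_bytes_sketch + scale_bytes_sketch + 127) / 128) * 128;
static_assert(stage_bytes_sketch == 32896, "per-stage bytes before pipeline storage");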
@@ -1043,7 +1043,8 @@ template <
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
int ScaleGranularityM_
int ScaleGranularityM_,
int ScaleGranularityN_
>
struct CollectiveBuilder<
arch::Sm90,
@@ -1058,11 +1059,11 @@ struct CollectiveBuilder<
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>,
cute::enable_if_t<
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>;
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
@@ -1090,7 +1091,7 @@ struct CollectiveBuilder<
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>>;
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
@@ -1109,12 +1110,15 @@ struct CollectiveBuilder<
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape_MNK{}) : ScaleGranularityM_;
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape_MNK{}) : ScaleGranularityN_;
static constexpr int ScaleMsPerTile = size<0>(TileShape_MNK{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape_MNK{}) / ScaleGranularityN;
static_assert((size<0>(TileShape_MNK{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
static_assert((size<1>(TileShape_MNK{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N.");
static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_>;
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile, ScaleNsPerTile>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_>;
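// Editor's note: worked example of the granularity arithmetic above, for illustration
// only (TileShape_MNK = (128, 128, 128) is assumed). A granularity of 0 means one scale
// per tile along that mode.
constexpr int tile_m_sketch = 128, tile_n_sketch = 128;
constexpr int granularity_m_sketch = 0;    // 0 -> one scale per tile along M
constexpr int granularity_n_sketch = 64;   // two scale values per tile along N
constexpr int gm_sketch = granularity_m_sketch == 0 ? tile_m_sketch : granularity_m_sketch;
constexpr int gn_sketch = granularity_n_sketch == 0 ? tile_n_sketch : granularity_n_sketch;
static_assert(tile_m_sketch / gm_sketch == 1 && tile_n_sketch / gn_sketch == 2,
              "ScaleMsPerTile = 1, ScaleNsPerTile = 2 for these assumed values");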
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;

View File

@@ -75,6 +75,15 @@ private:
}
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
CUTLASS_DEVICE
void scale_core(ElementAccumulator const &scale) {
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
accum_(i) += accum_temp_(i) * scale;
}
}
template <
class EngineScale,
class LayoutScale>
@@ -94,6 +103,31 @@ private:
}
}
template <
class EngineScaleA,
class LayoutScaleA,
class EngineScaleB,
class LayoutScaleB>
CUTLASS_DEVICE
void scale_core(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
using TensorScaleA = cute::Tensor<EngineScaleA, LayoutScaleA>;
using TensorScaleB = cute::Tensor<EngineScaleB, LayoutScaleB>;
static_assert(is_static<LayoutScaleA>::value, "ScaleA Layout should be static");
static_assert(is_static<LayoutScaleB>::value, "ScaleB Layout should be static");
static_assert(is_rmem<TensorScaleA>::value, "ScaleA tensor must be rmem resident.");
static_assert(is_rmem<TensorScaleB>::value, "ScaleB tensor must be rmem resident.");
static_assert(LayoutAccum{}.shape() == LayoutScaleA{}.shape(), "Accumulator and scaleA must have same shape.");
static_assert(LayoutAccum{}.shape() == LayoutScaleB{}.shape(), "Accumulator and scaleB must have same shape.");
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
accum_(i) += accum_temp_(i) * scaleA(i) * scaleB(i);
}
}
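// Editor's note: scalar sketch of the element-wise promotion above, for illustration
// only; plain arrays stand in for the register-resident accumulator and scale tensors.
inline void promote_sketch(float* accum, float const* partial,
                           float const* scale_a, float const* scale_b, int n) {
  for (int i = 0; i < n; ++i) {
    accum[i] += partial[i] * scale_a[i] * scale_b[i];  // one FMA chain per element
  }
}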
public:
CUTLASS_DEVICE
GmmaFP8Accumulation(
@@ -152,6 +186,16 @@ public:
//
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
CUTLASS_DEVICE
void scale_if_needed(ElementAccumulator const &scale) {
mma_count_ += mma_count_per_mainloop_iteration_;
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
if (reset_accum_flag_) {
scale_core(scale);
mma_count_ = 0;
}
}
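// Editor's note: illustrative sketch of the promotion cadence driving scale_if_needed
// above; the counts are made up, and the real code broadcasts the flag with __shfl_sync
// so it is warp-uniform.
struct PromotionCadenceSketch {
  int mma_count = 0;
  int mmas_per_iteration = 4;    // stands in for mma_count_per_mainloop_iteration_
  int promotion_interval = 16;   // stands in for accum_promotion_interval_
  bool step() {                  // returns true on iterations that promote and reset
    mma_count += mmas_per_iteration;
    bool promote = (mma_count == promotion_interval);
    if (promote) { mma_count = 0; }
    return promote;
  }
};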
template <
class EngineScale,
class LayoutScale>
@@ -165,7 +209,29 @@ public:
}
}
template <
class EngineScaleA,
class LayoutScaleA,
class EngineScaleB,
class LayoutScaleB>
CUTLASS_DEVICE
void scale_if_needed(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
mma_count_ += mma_count_per_mainloop_iteration_;
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
if (reset_accum_flag_) {
scale_core(scaleA, scaleB);
mma_count_ = 0;
}
}
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
CUTLASS_DEVICE
void scale_residue_if_needed(ElementAccumulator const &scale) {
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
scale_core(scale);
}
}
template <
class EngineScale,
class LayoutScale>
@@ -175,6 +241,18 @@ public:
scale_core(scale);
}
}
template <
class EngineScaleA,
class LayoutScaleA,
class EngineScaleB,
class LayoutScaleB>
CUTLASS_DEVICE
void scale_residue_if_needed(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
scale_core(scaleA, scaleB);
}
}
};
} // namespace cutlass::gemm::collective

View File

@@ -30,8 +30,6 @@
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
@@ -288,23 +286,23 @@ struct CollectiveMma<
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
static constexpr uint32_t SFTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v<ElementSF>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v<ElementSF>);
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
static constexpr uint32_t ABTmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes;
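// Editor's note: worked instance of the transaction-byte accounting above, for
// illustration only. Assumed values: a 2-CTA MMA atom and 128x64 A/B tiles of 8-bit
// elements; scale-factor bytes are left out. Only the leader CTA issues the TMA, so it
// signals the bytes for both CTAs' tiles.
constexpr int atom_thr_sketch     = 2;              // size(AtomThrShapeMNK{})
constexpr int a_tile_bits_sketch  = 128 * 64 * 8;
constexpr int b_tile_bits_sketch  = 128 * 64 * 8;
constexpr int ab_tma_bytes_sketch = atom_thr_sketch * (a_tile_bits_sketch + b_tile_bits_sketch) / 8;
static_assert(ab_tma_bytes_sketch == 32768, "AB transaction bytes for the assumed tiles");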
template<class AccTensor, class SfaTensor, class SfbTensor>
template <class AccTensor, class SfaTensor, class SfbTensor>
struct TmemStorage {
AccTensor accumulators;
SfaTensor tCtSFA;
SfbTensor tCtSFB;
};
template<
template <
class KTileCount,
class GTensorPartitionedA, class GTensorPartitionedB,
class STensorA, class STensorB,
@@ -348,7 +346,8 @@ struct CollectiveMma<
, mcast_mask_sfa(mcast_mask_sfa_), mcast_mask_sfb(mcast_mask_sfb_) {}
};
template<
template <
class TiledMma,
class FragmentA, class FragmentB,
class FragmentSFA, class FragmentSFB,
class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA,
@@ -496,6 +495,7 @@ struct CollectiveMma<
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape);
// Cluster layout for TMA construction
auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{}));
auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback);
@@ -505,7 +505,7 @@ struct CollectiveMma<
// Cluster layout for TMA construction of SFB
auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{}));
auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{}));
auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{}));
typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100<TmaInternalElementA>(
GmemTiledCopyA{},
@@ -649,7 +649,7 @@ struct CollectiveMma<
return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage));
}
template<class EpilogueTile, bool IsOverlappingAccum = false>
template <class EpilogueTile, bool IsOverlappingAccum = false>
CUTLASS_DEVICE static
auto
init_tmem_tensors(EpilogueTile epi_tile) {
@@ -660,7 +660,7 @@ struct CollectiveMma<
tiled_mma, acc_shape, EpilogueTile{});
Tensor tCtSFA = make_tensor<typename TiledMma::FrgTypeSFA>(shape(SmemLayoutAtomSFA{}));
Tensor tCtSFB = make_tensor<typename TiledMma::FrgTypeSFB>(shape(SmemLayoutAtomSFB{}));
TmemStorage<decltype(accumulators), decltype(tCtSFA), decltype(tCtSFB)> tmem_storage;
tmem_storage.accumulators = accumulators;
tmem_storage.tCtSFA = tCtSFA;
@@ -669,10 +669,10 @@ struct CollectiveMma<
return tmem_storage;
}
template<class AccTensor, class SfaTensor, class SfbTensor>
template <class TmemStorage>
CUTLASS_DEVICE static
void
set_tmem_offsets(TmemStorage<AccTensor, SfaTensor, SfbTensor>& tmem_storage, uint32_t tmem_base_addr) {
set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
tmem_storage.accumulators.data() = tmem_base_addr;
tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators);
tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA);
@@ -751,7 +751,6 @@ struct CollectiveMma<
Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});
// Define the CTA-in-cluster Layout and Coord
Layout cta_layout_mnk = make_layout(cluster_shape_);
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_);
@@ -785,13 +784,11 @@ struct CollectiveMma<
uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk);
LoadParams load_params {
return LoadParams{
size<3>(gA_mkl), // for scheduler
tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values
tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values
mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb // multicast masks
};
return load_params;
mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb}; // multicast masks
}
/// Set up the data needed by this collective for mma compute.
@@ -802,8 +799,8 @@ struct CollectiveMma<
TensorStorage& shared_tensors) const {
// Allocate "fragments/descriptors" for A and B matrices
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
// Allocate "fragments/descriptors" for A and B matrices
Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
@@ -854,17 +851,12 @@ struct CollectiveMma<
tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111;
tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111;
}
MmaParams<
decltype(tCrA), decltype(tCrB), decltype(tCtSFA), decltype(tCtSFB),
decltype(tiled_copy_s2t_SFA), decltype(thr_tCsSFA_compact_s2t), decltype(thr_tCtSFA_compact_s2t),
decltype(tiled_copy_s2t_SFB), decltype(thr_tCsSFB_compact_s2t), decltype(thr_tCtSFB_compact_s2t)
> mma_params {
return MmaParams{
tiled_mma,
tCrA, tCrB, tCtSFA, tCtSFB,
tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t,
tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t
};
return mma_params;
tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t};
}
/// Perform a collective-scoped matrix multiply-accumulate
@@ -983,52 +975,12 @@ struct CollectiveMma<
uint32_t skip_wait = k_tile_count <= 0;
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
bool is_first_iter = true;
//
// PIPELINED MAIN LOOP
//
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
if (k_tile_count > 0) { // first iteraion
// WAIT on mainloop_pipe_consumer_state until its data are available
// (phase bit flips from mainloop_pipe_consumer_state.phase() value)
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
// Compute on k_tile
int read_stage = mainloop_pipe_consumer_state.index();
// Save current mainlop pipeline read state
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
// Advance mainloop_pipe
++mainloop_pipe_consumer_state;
--k_tile_count;
skip_wait = k_tile_count <= 0;
// Peek at next iteration
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
if (cute::elect_one_sync()) {
copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t);
copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
}
if constexpr (IsOverlappingAccum) {
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
}
// Unroll the K mode manually so we can set scale C to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma.with(tiled_mma.accumulate_,
tCtSFA(_,_,k_block),
tCtSFB_mma(_,_,k_block)),
tCrA(_,_,k_block,read_stage),
tCrB(_,_,k_block,read_stage),
accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
// WAIT on mainloop_pipe_consumer_state until its data are available
@@ -1052,6 +1004,13 @@ struct CollectiveMma<
copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
}
if constexpr (IsOverlappingAccum) {
if (is_first_iter) {
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
is_first_iter = false;
}
}
// Unroll the K mode manually so we can set scale C to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
@@ -1064,6 +1023,7 @@ struct CollectiveMma<
accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}

View File

@@ -31,7 +31,6 @@
#pragma once
#include "cutlass/cutlass.h"
@@ -239,12 +238,12 @@ struct CollectiveMma<
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
template<class AccTensor>
template <class AccTensor>
struct TmemStorage {
AccTensor accumulators;
};
template<
template <
class KTileCount,
class GTensorPartitionedA, class GTensorPartitionedB,
class STensorA, class STensorB
@@ -273,7 +272,10 @@ struct CollectiveMma<
, mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {}
};
template<class FragmentA, class FragmentB>
template <
class TiledMma,
class FragmentA, class FragmentB
>
struct MmaParams {
TiledMma tiled_mma;
FragmentA tCrA;
@@ -336,7 +338,7 @@ struct CollectiveMma<
, runtime_data_type_a_(params.runtime_data_type_a)
, runtime_data_type_b_(params.runtime_data_type_b) {
if constexpr (IsDynamicCluster) {
const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
observed_tma_load_a_ = is_fallback_cluster ? &params.tma_load_a_fallback : &params.tma_load_a;
observed_tma_load_b_ = is_fallback_cluster ? &params.tma_load_b_fallback : &params.tma_load_b;
@@ -461,7 +463,7 @@ struct CollectiveMma<
return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage));
}
template<class EpilogueTile, bool IsOverlappingAccum = false>
template <class EpilogueTile, bool IsOverlappingAccum = false>
CUTLASS_DEVICE static
auto
init_tmem_tensors(EpilogueTile epi_tile) {
@@ -475,10 +477,10 @@ struct CollectiveMma<
return tmem_storage;
}
template<class AccTensor>
template <class TmemStorage>
CUTLASS_DEVICE static
void
set_tmem_offsets(TmemStorage<AccTensor>& tmem_storage, uint32_t tmem_base_addr) {
set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
tmem_storage.accumulators.data() = tmem_base_addr;
}
@@ -535,21 +537,21 @@ struct CollectiveMma<
// TMA Multicast Masks
uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
LoadParams load_params {
return LoadParams{
shape<3>(gA_mkl), // for scheduler
tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values
mcast_mask_a, mcast_mask_b // multicast masks
};
return load_params;
mcast_mask_a, mcast_mask_b}; // multicast masks
}
/// Set up the data needed by this collective for mma compute.
template <class AccTensor>
template <class TmemStorage>
CUTLASS_DEVICE auto
mma_init(
[[maybe_unused]] TmemStorage<AccTensor> tmem_tensors,
TensorStorage& shared_tensors) const {
[[maybe_unused]] TmemStorage tmem_storage,
TensorStorage& shared_tensors) const {
// Allocate "fragments/descriptors" for A and B matrices
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
@@ -558,7 +560,7 @@ struct CollectiveMma<
Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB));
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB)); // PIPE
TiledMma tiled_mma;
@@ -568,11 +570,10 @@ struct CollectiveMma<
tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111;
tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111;
}
MmaParams<decltype(tCrA), decltype(tCrB)> mma_params {
return MmaParams{
tiled_mma,
tCrA, tCrB
};
return mma_params;
tCrA, tCrB};
}
/// Perform a collective-scoped matrix multiply-accumulate
@@ -657,6 +658,7 @@ struct CollectiveMma<
) {
static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
auto accumulators = get<0>(accumulators_pair);
auto [tiled_mma, tCrA, tCrB] = mma_inputs;

View File

@@ -58,6 +58,7 @@ template <
class ClusterShape,
class KernelSchedule,
int ScaleGranularityM_,
int ScaleGranularityN_,
class TileShape_,
class ElementA_,
class StrideA_,
@@ -73,7 +74,7 @@ template <
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>,
TileShape_,
ElementA_,
StrideA_,
@@ -92,7 +93,7 @@ struct CollectiveMma<
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
@@ -120,7 +121,9 @@ struct CollectiveMma<
static constexpr int NumProducerThreadEvents = 2;
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -131,6 +134,7 @@ struct CollectiveMma<
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
@@ -144,12 +148,13 @@ struct CollectiveMma<
// Block scaling gmem-to-smem copy atom
using BlockScaleCopyTypeA = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>;
using BlockScaleCopyTypeB = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleNsPerTile, 16)>;
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeA>, ElementBlockScale>;
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeB>, ElementBlockScale>;
// Block scaling smem layout
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
using SmemLayoutScaleB = Layout<Shape<Int<ScaleNsPerTile>, Int<DispatchPolicy::Stages>>>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
@@ -168,7 +173,7 @@ struct CollectiveMma<
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // ScaleNsPerTile x k
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
@@ -322,17 +327,17 @@ struct CollectiveMma<
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
// Make the tiled views of scale tensors
auto scaleA_shape = make_shape(get<2>(gA_mkl.shape()), Int<ScaleMsPerTile>{}, get<3>(gA_mkl.shape()), get<4>(gA_mkl.shape())); // (m,ScaleMsPerTile,k,l)
auto scale_dA = make_stride(get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{}, Int<1>{}, Int<ScaleMsPerTile>{}, get<2>(gA_mkl.shape()) * get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{});
auto scaleA_shape = make_shape(shape<2>(gA_mkl), Int<ScaleMsPerTile>{}, shape<3>(gA_mkl), shape<4>(gA_mkl)); // (m,ScaleMsPerTile,k,l)
auto scaleB_shape = make_shape(shape<2>(gB_nkl), Int<ScaleNsPerTile>{}, shape<3>(gB_nkl), shape<4>(gB_nkl)); // (n,ScaleNsPerTile,k,l)
auto scale_dA = compact_order(scaleA_shape, Step<_2,_0,_1,_3>{});
auto scale_dB = compact_order(scaleB_shape, Step<_2,_0,_1,_3>{});
auto scaleA_layout = make_layout(scaleA_shape, scale_dA);
auto scaleB_shape = make_shape(get<2>(gB_nkl.shape()), get<3>(gB_nkl.shape()), get<4>(gB_nkl.shape())); // (n,k,l)
auto scale_dB = make_stride(get<3>(gB_nkl.shape()), Int<1>{}, get<2>(gB_nkl.shape()) * get<3>(gB_nkl.shape()));
auto scaleB_layout = make_layout(scaleB_shape, scale_dB);
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l)
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,ScaleNsPerTile,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
}
@@ -356,13 +361,13 @@ struct CollectiveMma<
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k)
//
// Prepare the TMA loads for A and B
@@ -388,10 +393,10 @@ struct CollectiveMma<
Tensor mScaleB_nkl = get<3>(load_inputs);
Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1)
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
Tensor gScaleB = mScaleB_nkl(n_coord,_,_,l_coord); // (1,ScaleNsPerTile,k,1)
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleMsPerTile>>>{}); // (1,ScaleMsPerTile,1)
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleNsPerTile>>>{}); // (1,ScaleNsPerTile,1)
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
@@ -446,7 +451,7 @@ struct CollectiveMma<
// Copy scale tensors from global memory to shared memory
copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
copy(scale_copy_b, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage));
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
++k_tile_iter;
@@ -508,7 +513,11 @@ struct CollectiveMma<
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()),
Layout<
Shape<cute::tuple_element_t<0, TileShape>, Shape<Int<ScaleGranularityN>, Int<ScaleNsPerTile>>, Int<DispatchPolicy::Stages>>,
Stride<_0, Stride<_0, _1>, Int<ScaleNsPerTile>>
>{}); // (m,(ScaleGranularityN,ScaleNsPerTile),k)
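// Editor's note: illustrative reading of the broadcast layout above, not part of this
// commit. With assumed values TileShape_N = 128, ScaleGranularityN = 64 and
// ScaleNsPerTile = 2, every (m, n) coordinate maps to scale index n / 64 within the
// current stage, because the M stride and the intra-granule N stride are both 0.
constexpr int scale_b_index_sketch(int /*m*/, int n, int stage) {
  constexpr int granularity_n = 64, ns_per_tile = 2;
  return (n / granularity_n) + stage * ns_per_tile;
}
static_assert(scale_b_index_sketch(7, 10, 0) == 0 && scale_b_index_sketch(0, 70, 0) == 1,
              "rows share the scale of their N-granule");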
//
// Define C accumulators and A/B partitioning
@@ -531,7 +540,8 @@ struct CollectiveMma<
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
Tensor tCsScaleBViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
@@ -557,11 +567,8 @@ struct CollectiveMma<
PipelineState smem_pipe_release = smem_pipe_read;
// Per block scale values for operand A and B
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
ElementBlockScale scale_b;
Tensor tCrScaleAViewAsC = make_tensor_like<ElementBlockScale>(tCsScaleAViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N)
Tensor tCrScaleBViewAsC = make_tensor_like<ElementBlockScale>(tCsScaleBViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N)
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
@@ -583,21 +590,26 @@ struct CollectiveMma<
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers.
scale_b = sScaleB[read_stage];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
// Load per block scale values from shared memory to registers
copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC);
copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC);
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0];
}
if constexpr (ScaleMsPerTile == 1) {
static_assert(size(RegLayoutScaleAEssential{}) == 1);
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
} else {
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
for (int i = 0; i < size(tCrScaleAViewAsC); i++) {
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
}
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCrScaleBViewAsC); i++) {
tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a;
}
}
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
@@ -609,8 +621,20 @@ struct CollectiveMma<
}
warpgroup_commit_batch();
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
accumulation.scale_if_needed(tCrScaleAViewAsC);
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
}
++smem_pipe_read;
}
@@ -632,21 +656,26 @@ struct CollectiveMma<
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
scale_b = sScaleB[read_stage];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
// Load per block scale values from shared memory to registers (at most twice per block along M and/or N)
copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC);
copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC);
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0];
}
if constexpr (ScaleMsPerTile == 1) {
static_assert(size(RegLayoutScaleAEssential{}) == 1);
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
} else {
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
for (int i = 0; i < size(tCrScaleAViewAsC); i++) {
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
}
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCrScaleBViewAsC); i++) {
tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a;
}
}
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
@@ -667,8 +696,20 @@ struct CollectiveMma<
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation());
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
accumulation.scale_if_needed(tCrScaleAViewAsC);
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
}
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
@@ -677,7 +718,19 @@ struct CollectiveMma<
++smem_pipe_release;
}
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_residue_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
}
warpgroup_fence_operand(accumulation());
}
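For reference, the four `if constexpr` branches above all implement the same blockwise-scaling math; they only differ in which scale mode collapses to a single value per tile. A minimal, illustrative C++ sketch of that math follows, assuming hypothetical names: `TileM`/`TileN` for the output-tile extents, plain arrays instead of the kernel's CuTe tensors, and `acc`/`partial` for the persistent and per-stage accumulators.
// Illustrative only: a plain C++ rendition of what scale_if_needed() folds into the
// FP8 accumulation promotion. Names and types are placeholders, not the kernel's.
#include <array>
template <int TileM, int TileN, int ScaleGranularityM, int ScaleGranularityN>
void scale_accumulators(std::array<std::array<float, TileN>, TileM>&       acc,      // persistent accumulator
                        std::array<std::array<float, TileN>, TileM> const& partial,  // latest GMMA partial sums
                        float const* scale_a,   // TileM / ScaleGranularityM scales along M
                        float const* scale_b) { // TileN / ScaleGranularityN scales along N
  for (int m = 0; m < TileM; ++m) {
    for (int n = 0; n < TileN; ++n) {
      // Each output element is scaled by the A-scale of its M-block and the B-scale of its
      // N-block; when a mode has a single scale per tile, its index degenerates to 0.
      acc[m][n] += scale_a[m / ScaleGranularityM] * scale_b[n / ScaleGranularityN] * partial[m][n];
    }
  }
}
The `ScaleMsPerTile == 1` / `ScaleNsPerTile == 1` cases above are exactly the degenerate forms of this loop, which is why they can pre-multiply a scalar instead of carrying a full register tensor.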

View File

@@ -117,7 +117,11 @@ struct KernelPtrArrayTmaWarpSpecializedPingpong { };
// FP8 related policies (including Blocked Scaled Accumulation)
template<
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M.
// `ScaleGranularityM`/`ScaleGranularityN` specify the scaling granularity along M/N; a value of
// zero indicates that the scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0
>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
@@ -302,12 +306,16 @@ template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelTmaWarpSpecialized,
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M.
// `ScaleGranularityM`/`ScaleGranularityN` specify the scaling granularity along M/N; a value of
// zero indicates that the scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>>,
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>>,
"KernelSchedule must be one of the warp specialized policies");
};
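With the second granularity parameter in place, a block-scaled schedule and its matching mainloop policy might be spelled as below. This is a hypothetical instantiation: the stage count, cluster shape, and the 128x128 granularities are illustrative values, not defaults, and in practice the schedule tag is typically handed to the CollectiveBuilder rather than paired with the policy by hand.
// Hypothetical instantiation showing where the new ScaleGranularityN lands.
#include "cute/layout.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<
    /*ScaleGranularityM=*/128, /*ScaleGranularityN=*/128>;
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<
    /*Stages_=*/4,
    /*ClusterShape_=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
    KernelSchedule,
    /*ScaleGranularityM=*/128, /*ScaleGranularityN=*/128>;  // must match the schedule, per the static_assert
The static_assert in the hunk above is what rejects a granularity mismatch between the schedule and the mainloop policy.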

View File

@@ -397,8 +397,6 @@ public:
// An example of an unneeded threadblock is one that is assigned to compute in the upper
// portion of a Rank2K kernel filled with mode kLower.
//
// TODO: Consider pushing these checks into ProblemVisitor to avoid spuriously
// returning from `next_tile()`.
//
// Early exit if threadblock is out of range

View File

@@ -1131,6 +1131,10 @@ public:
}
}
else {
// Register reconfiguration
arch::warpgroup_reg_dealloc<GenericRegisterRequirement>();
}
}
};

View File

@@ -29,8 +29,6 @@
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
@@ -564,20 +562,21 @@ public:
// Sync deallocation status between MMA warps of peer CTAs
arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc;
[[maybe_unused]] uint32_t dealloc_barrier_phase = 0;
if constexpr(!IsOverlappingAccum) {
if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) {
tmem_deallocation_result_barrier.init(NumMMAThreads);
if (WarpCategory::MMA == warp_category) {
if constexpr(!IsOverlappingAccum) {
if (has_mma_peer_cta && lane_predicate) {
tmem_deallocation_result_barrier.init(NumMMAThreads);
}
}
else {
if (has_mma_peer_cta && lane_predicate) {
tmem_deallocation_result_barrier.init(NumEpilogueThreads*2);
}
else if (lane_predicate) {
tmem_deallocation_result_barrier.init(NumEpilogueThreads);
}
}
}
else {
if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) {
tmem_deallocation_result_barrier.init(NumEpilogueThreads*2);
}
else if (WarpCategory::MMA == warp_category && lane_predicate) {
tmem_deallocation_result_barrier.init(NumEpilogueThreads);
}
}
// Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes.
arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle;
@@ -699,7 +698,6 @@ public:
epilogue_throttle_barrier.arrive();
if constexpr (IsSchedDynamicPersistent) {
// Whether a new CLC query must be performed.
// See comment below where this variable is updated for a description of
// why this variable is needed.
@@ -738,7 +736,6 @@ public:
work_tile_info = next_work_tile_info;
} while (work_tile_info.is_valid());
clc_pipeline.producer_tail(clc_pipe_producer_state);
}
}
@@ -963,7 +960,6 @@ public:
epi_load_pipe_consumer_state = load_state_next;
epi_store_pipe_producer_state = store_state_next;
accumulator_pipe_consumer_state = acc_state_next;
do_tail_store = true;
}
work_tile_info = next_work_tile_info;

View File

@@ -1057,6 +1057,10 @@ public:
}
}
else {
// Register reconfiguration
arch::warpgroup_reg_dealloc<GenericRegisterRequirement>();
}
}
};

View File

@@ -783,7 +783,6 @@ private:
int L_idx, Split_idx;
params_.sk_params_.divmod_splits_(L_idx, Split_idx, work_tile_info.L_idx);
// TODO: Modularize the SM90 scheduler to pull out and reuse this redundant code
int additional_k_tiles = 0;
int split_start_offset = params_.sk_params_.big_units_;

View File

@@ -455,8 +455,9 @@ public:
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
TileScheduler scheduler{params.scheduler};
auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
// Declare work_tile_info, then define it in each of the warp groups that use it.
typename TileScheduler::WorkTileInfo work_tile_info;
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
@@ -474,6 +475,7 @@ public:
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer) {
work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
@@ -578,6 +580,7 @@ public:
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
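The hunk above stops querying the tile scheduler unconditionally and instead declares `work_tile_info` uninitialized, so only the producer and consumer warp groups perform the initial query. A standalone sketch of the pattern, with an illustrative role enum, scheduler stub, and field names that are not the kernel's types:
// Illustrative only: declare once, define per role that actually consumes the tile.
#include <cstdio>
struct WorkTileInfo { int m = -1, n = -1; bool valid = false; };
struct Scheduler {
  WorkTileInfo initial_work_tile_info() const { return {0, 0, true}; }  // placeholder query
};
enum class Role { Producer, Consumer0, Consumer1, Other };
void kernel_body(Role role, Scheduler const& scheduler) {
  WorkTileInfo work_tile_info;  // declared once, deliberately left unset
  if (role != Role::Other) {
    // Only the warp groups that consume the tile pay for the scheduler query.
    work_tile_info = scheduler.initial_work_tile_info();
  }
  if (work_tile_info.valid) {
    std::printf("tile (%d, %d)\n", work_tile_info.m, work_tile_info.n);
  }
}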

View File

@@ -265,7 +265,7 @@ struct PersistentTileSchedulerSm90Params {
}
// In case the maximum number of clusters that could co-exist on the target device is
// already calculated using cudaOccupancyMaxActiveClusters
else if (max_active_clusters != 0) {
else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) {
if (raster_order == RasterOrder::AlongN) {
launch_grid.y = max_active_clusters * cluster_shape.n();
}
@@ -1204,6 +1204,7 @@ struct PersistentTileSchedulerSm90StreamKParams {
KernelHardwareInfo new_hw_info;
new_hw_info.device_id = hw_info.device_id;
new_hw_info.sm_count = hw_info.sm_count;
new_hw_info.max_active_clusters = hw_info.max_active_clusters;
if (new_hw_info.sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
@@ -1787,7 +1788,7 @@ struct PersistentTileSchedulerSm90GroupParams {
}
// In case the maximum number of clusters that could co-exist on the target device is
// already calculated using cudaOccupancyMaxActiveClusters
else if (max_active_clusters != 0) {
else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) {
if (raster_order == RasterOrder::AlongN) {
launch_grid.y = max_active_clusters * cluster_shape.n();
}
@@ -2499,6 +2500,7 @@ struct PersistentTileSchedulerSm100GroupParams {
bool is_static_cluster_shape = false) {
int const sm_count = hw_info.sm_count;
int const max_active_clusters = hw_info.max_active_clusters;
// Round up to nearest multiple of swizzle_size along each mode
auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size);
@@ -2542,6 +2544,18 @@ struct PersistentTileSchedulerSm100GroupParams {
launch_grid.x = possibly_truncate(sm_count, problem_blocks_total);
}
}
// In case the maximum number of clusters that could co-exist on the target device is
// already calculated using cudaOccupancyMaxActiveClusters
else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) {
if (raster_order == RasterOrder::AlongN) {
launch_grid.y = max_active_clusters * cluster_shape.n();
}
else {
launch_grid.x = max_active_clusters * cluster_shape.m();
}
CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using cudaOccupancyMaxActiveClusters = "
"(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n");
}
else {
constexpr int max_sm_per_gpc = 20;
int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count);
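The new guard only trusts `max_active_clusters` when the resulting grid still fits on the device (`max_active_clusters * cluster_size <= sm_count`); otherwise the scheduler falls back to the occupancy heuristic below. A host-side sketch of how a caller might pre-populate `KernelHardwareInfo::max_active_clusters` with the CUDA runtime occupancy query; `my_kernel`, the block shape, cluster shape, and shared-memory size are placeholders for the real launch parameters.
// Illustrative only: query how many clusters of the given shape can be co-resident.
#include <cuda_runtime.h>
__global__ void my_kernel() {}  // placeholder for the actual GEMM kernel
int query_max_active_clusters() {
  cudaLaunchConfig_t config = {};
  config.gridDim          = dim3(2, 1, 1);    // one cluster's worth of CTAs
  config.blockDim         = dim3(128, 1, 1);  // threads per CTA of the real kernel
  config.dynamicSmemBytes = 0;                // the real kernel's dynamic smem size
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeClusterDimension;
  attrs[0].val.clusterDim.x = 2;              // 2x1x1 cluster in this sketch
  attrs[0].val.clusterDim.y = 1;
  attrs[0].val.clusterDim.z = 1;
  config.attrs    = attrs;
  config.numAttrs = 1;
  int max_active_clusters = 0;
  // Ask the runtime how many such clusters can be resident at once on this device.
  cudaOccupancyMaxActiveClusters(&max_active_clusters, (void const*)my_kernel, &config);
  return max_active_clusters;                 // feed into KernelHardwareInfo::max_active_clusters
}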

View File

@@ -142,7 +142,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
"Shape of warp-level Mma must be divisible by operator shape.");
// Shape of one individual LDS.128
// TODO: 32 and 4 are hardcoded, 32-by-4 is logical shape
using LdsShape = layout::PitchLinearShape<
32,
4
@@ -458,7 +457,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
"Shape of warp-level Mma must be divisible by operator shape.");
// Shape of one individual LDS
// TODO: remove hardcoded 32 and 4
using LdsShape = layout::PitchLinearShape<
32,
4

View File

@@ -995,7 +995,6 @@ public:
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) {
// TODO: fix this if it becomes an issue during warp it reset
add_tile_offset(tile_offset);
return *this;

View File

@@ -41,7 +41,6 @@
#pragma once
#include <cuda/std/cassert>
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/pitch_linear.h"

View File

@@ -82,7 +82,7 @@ struct get_unpacked_element_type {
#include "cutlass/tfloat32.h"
#include "cutlass/float8.h"
#include "cutlass/uint128.h"
#include "cutlass/exmy_base.h"
#include "cutlass/float_subbyte.h"
#include "cutlass/exmy_base.h"
#include "cutlass/float_subbyte.h"
/////////////////////////////////////////////////////////////////////////////////////////////////