mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
update 3.8 v2 (#2112)
* update 3.8 v2 * update 3.8 --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@@ -86,5 +86,3 @@
|
||||
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -208,6 +208,7 @@ to_CUtensorMapDataType() {
|
||||
if constexpr (is_same_v<T, uint8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (is_same_v<T, float_e4m3_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (is_same_v<T, float_e5m2_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (is_same_v<T, float_ue8m0_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else
|
||||
if constexpr (is_same_v<T, uint16_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else
|
||||
if constexpr (is_same_v<T, uint32_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else
|
||||
|
||||
@@ -956,7 +956,7 @@ template <class a_type, class b_type, class c_type, class sf_type,
|
||||
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
|
||||
struct SM100_MMA_MXF4_SS
|
||||
{
|
||||
static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster OMMA.");
|
||||
static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster MMA.");
|
||||
static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4_SS N-mode size should be a multiple of 8 between 8 and 256.");
|
||||
static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_SS Vector size can only be 16 or 32.");
|
||||
|
||||
|
||||
@@ -45,7 +45,6 @@
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM100_U8x8_LDSM_T>
|
||||
{
|
||||
|
||||
@@ -372,7 +372,7 @@ void swap(array<T,N>& a, array<T,N>& b)
|
||||
/// @return A cute::array of the elements of @c t in reverse order.
|
||||
template <class T, size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::array<T,N> reverse(cute::array<T,N> const& t)
|
||||
cute::array<T,N> reverse(cute::array<T,N> const& t)
|
||||
{
|
||||
if constexpr (N == 0u) {
|
||||
return t;
|
||||
@@ -441,17 +441,6 @@ struct tuple_element<I, cute::array<T,N>>
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class T, size_t N>
|
||||
struct tuple_size<cute::array<T,N> const>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
|
||||
{};
|
||||
|
||||
template <size_t I, class T, size_t N>
|
||||
struct tuple_element<I, cute::array<T,N> const>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end namespace CUTE_STL_NAMESPACE
|
||||
|
||||
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
@@ -477,16 +466,5 @@ struct tuple_element<I, cute::array<T,N>>
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class T, size_t N>
|
||||
struct tuple_size<cute::array<T,N> const>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
|
||||
{};
|
||||
|
||||
template <size_t I, class T, size_t N>
|
||||
struct tuple_element<I, cute::array<T,N> const>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
|
||||
@@ -611,17 +611,6 @@ struct tuple_element<I, cute::array_subbyte<T,N>>
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class T, size_t N>
|
||||
struct tuple_size<const cute::array_subbyte<T,N>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
|
||||
{};
|
||||
|
||||
template <size_t I, class T, size_t N>
|
||||
struct tuple_element<I, const cute::array_subbyte<T,N>>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end namespace CUTE_STL_NAMESPACE
|
||||
|
||||
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
@@ -647,16 +636,5 @@ struct tuple_element<I, cute::array_subbyte<T,N>>
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class T, size_t N>
|
||||
struct tuple_size<const cute::array_subbyte<T,N>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
|
||||
{};
|
||||
|
||||
template <size_t I, class T, size_t N>
|
||||
struct tuple_element<I, const cute::array_subbyte<T,N>>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/container/type_list.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Empty Structure Optimization
|
||||
template <bool IsFirstEmpty, bool IsRestEmpty, class... T>
|
||||
struct ESO;
|
||||
|
||||
template <class First, class... Rest>
|
||||
static constexpr bool is_first_empty_v = cute::is_empty<First>::value;
|
||||
template <class First, class... Rest>
|
||||
static constexpr bool is_rest_empty_v = (cute::is_empty<Rest>::value && ...);
|
||||
|
||||
template <class... T>
|
||||
using ESO_t = ESO<is_first_empty_v<T...>, is_rest_empty_v<T...>, T...>;
|
||||
|
||||
// Empty First and Empty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<true, true, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO() {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO(First const&, Rest const&...) {}
|
||||
};
|
||||
|
||||
// NonEmpty First and Empty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<false, true, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO() : first_{} {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO(First const& first, Rest const&...) : first_{first} {}
|
||||
|
||||
First first_;
|
||||
};
|
||||
|
||||
// Empty First and NonEmpty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<true, false, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO() : rest_{} {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO(First const&, Rest const&... rest) : rest_{rest...} {}
|
||||
|
||||
ESO_t<Rest...> rest_;
|
||||
};
|
||||
|
||||
// NonEmpty T and NonEmpty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<false, false, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO() : first_{}, rest_{} {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {}
|
||||
|
||||
First first_;
|
||||
ESO_t<Rest...> rest_;
|
||||
};
|
||||
|
||||
// Get Nth value from ESO
|
||||
template <size_t N, class T, class... Rest, bool F, bool R>
|
||||
CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...> const& s) {
|
||||
if constexpr (N == 0) {
|
||||
if constexpr (F) { return T{}; }
|
||||
else { return static_cast<T const&>(s.first_); }
|
||||
} else {
|
||||
if constexpr (R) { return cute::tuple_element_t<N-1, cute::type_list<Rest...>>{}; }
|
||||
else { return getv<N-1>(s.rest_); }
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t N, class T, class... Rest, bool F, bool R>
|
||||
CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...>& s) {
|
||||
if constexpr (N == 0) {
|
||||
if constexpr (F) { return T{}; }
|
||||
else { return static_cast<T&>(s.first_); }
|
||||
} else {
|
||||
if constexpr (R) { return cute::tuple_element_t<N-1, cute::type_list<Rest...>>{}; }
|
||||
else { return getv<N-1>(s.rest_); }
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t N, class T, class... Rest, bool F, bool R>
|
||||
CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...>&& s) {
|
||||
if constexpr (N == 0) {
|
||||
if constexpr (F) { return T{}; }
|
||||
else { return static_cast<T&&>(s.first_); }
|
||||
} else {
|
||||
if constexpr (R) { return cute::tuple_element_t<N-1, cute::type_list<Rest...>>{}; }
|
||||
else { return getv<N-1>(static_cast<ESO_t<Rest...>&&>(s.rest_)); }
|
||||
}
|
||||
}
|
||||
|
||||
// findt: Implementation detail of cute::find.
|
||||
// If X is the first template argument of the tuple, findt returns C<N>.
|
||||
|
||||
template <class X, size_t N,
|
||||
bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
findt(ESO<IsFirstEmpty, IsRestEmpty, First, Rest...> const& t) noexcept
|
||||
{
|
||||
if constexpr (cute::is_same_v<X, First>) {
|
||||
return C<N>{};
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof...(Rest) != 0,
|
||||
"The type does not appear in the argument list of the tuple.");
|
||||
if constexpr (IsRestEmpty) {
|
||||
// The rest is empty, so creating an instance of it is cheap.
|
||||
return cute::detail::findt<X, N+1>(ESO_t<Rest...>{});
|
||||
}
|
||||
else {
|
||||
return cute::detail::findt<X, N+1>(t.rest_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// packed_tuple<T...> is a tuple type that is a standard-layout type
|
||||
// whenever all of its template arguments are standard layout types:
|
||||
// (cute::is_standard_layout_v<T> && ...) implies (cute::is_standard_layout_v<packed_tuple<T...>>)
|
||||
|
||||
template <class... T>
|
||||
struct packed_tuple : detail::ESO_t<T...>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
packed_tuple() {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
packed_tuple(T const&... ts)
|
||||
: detail::ESO_t<T...>(ts...)
|
||||
{}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_tuple<> {};
|
||||
|
||||
template <size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(packed_tuple<T...> const& t) {
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(t);
|
||||
}
|
||||
|
||||
template <size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(packed_tuple<T...>& t) {
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(t);
|
||||
}
|
||||
|
||||
template <size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(packed_tuple<T...>&& t) {
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(static_cast<detail::ESO_t<T...>&&>(t));
|
||||
}
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
packed_tuple<T...>
|
||||
make_packed_tuple(T const&... t)
|
||||
{
|
||||
return {t...};
|
||||
}
|
||||
|
||||
// Returns the position of type X (as a static integer) in the tuple
|
||||
// type's argument list. X must be unique in the argument list.
|
||||
template <class X, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find(packed_tuple<T...> const& t) noexcept
|
||||
{
|
||||
return detail::findt<X, 0>(t);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace CUTE_STL_NAMESPACE
|
||||
{
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::packed_tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, cute::packed_tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace CUTE_STL_NAMESPACE
|
||||
|
||||
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
namespace std {
|
||||
|
||||
template <class ... T>
|
||||
struct tuple_size<cute::packed_tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class ... T>
|
||||
struct tuple_element<I, cute::packed_tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, cute::packed_tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
@@ -37,169 +37,183 @@
|
||||
|
||||
#include <cute/container/cuda_types.hpp>
|
||||
#include <cute/container/type_list.hpp>
|
||||
#if defined(CUTLASS_USE_PACKED_TUPLE)
|
||||
# include <cute/container/packed_tuple.hpp>
|
||||
#endif
|
||||
|
||||
//#include <cute/container/array.hpp> // Advanced optimizations
|
||||
|
||||
// cute::tuple is like std::tuple, with two differences.
|
||||
// cute::tuple is like std::tuple, with differences:
|
||||
//
|
||||
// 1. It works on both host and device.
|
||||
// 2. Its template arguments must be semiregular types.
|
||||
// 3. It is always a standard-layout type if all of its template arguments are standard-layout types.
|
||||
// 4. It is always an empty type if all of its template arguments are empty types.
|
||||
//
|
||||
// Semiregular types are default constructible and copyable.
|
||||
// They include "value types" like int or float,
|
||||
// but do _not_ include references like int& or float&.
|
||||
// (See std::tie for an example of a tuple of references.)
|
||||
//
|
||||
// If the template arguments of cute::tuple are all empty types (in
|
||||
// the sense of std::is_empty_v), then the cute::tuple is also an
|
||||
// empty type. Furthermore, if CUTLASS_USE_PACKED_TUPLE is defined,
|
||||
// cute::tuple is always a standard-layout type if all of its template
|
||||
// arguments are standard-layout types.
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
#if defined(CUTLASS_USE_PACKED_TUPLE)
|
||||
|
||||
template<class... T>
|
||||
using tuple = packed_tuple<T...>;
|
||||
|
||||
#else
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of
|
||||
// Standard-layout types preserve ABI across host-device boundaries.
|
||||
// They are safe to use as device kernel parameters.
|
||||
//
|
||||
// The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of
|
||||
// the conversion SFINAE, special overloading, and avoiding cvref template types.
|
||||
//
|
||||
// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x.
|
||||
|
||||
// EBO stands for "empty base optimization."
|
||||
namespace cute
|
||||
{
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
// ESO stands for "empty structure optimization."
|
||||
// We use this technique to ensure that cute::tuple
|
||||
// doesn't need to waste space storing any template arguments
|
||||
// of cute::tuple that have no data (like integral_constant).
|
||||
// Otherwise, cute::tuple would need to spend at least 1 byte
|
||||
// for each of its template arguments.
|
||||
//
|
||||
// This is one way in which cute::tuple differs from std::tuple.
|
||||
// doesn't waste space storing template arguments that have no data (like integral_constant).
|
||||
// Empty types in the template argument list are not even constructed,
|
||||
// and do not have unique element addresses. In fact, they are not
|
||||
// even members of the tuple or stored in any way. Calling `get`
|
||||
// and do not have unique element addresses. Calling `get`
|
||||
// constructs and returns an instance of an empty type on demand.
|
||||
//
|
||||
// EBO always "holds" a single value of type T.
|
||||
// N is like an array index that TupleBase uses
|
||||
// to access the desired tuple element.
|
||||
template <size_t N, class T, bool IsEmpty = is_empty<T>::value>
|
||||
struct EBO;
|
||||
|
||||
template <class T, size_t N, bool B>
|
||||
CUTE_HOST_DEVICE constexpr C<N> findt(EBO<N, T, B> const&)
|
||||
{ return {}; }
|
||||
template <bool IsFirstEmpty, bool IsRestEmpty, class... T>
|
||||
struct ESO;
|
||||
|
||||
// Specialization for types T that have no data;
|
||||
// the "static tuple leaf." Valid T here include
|
||||
// integral_constant<U, Value>, Int<Value>,
|
||||
// and any other semiregular type
|
||||
// for which std::is_empty_v<T> is true.
|
||||
template <size_t N, class T>
|
||||
struct EBO<N, T, true>
|
||||
{
|
||||
template <class First, class... Rest>
|
||||
static constexpr bool is_first_empty_v = cute::is_empty<First>::value;
|
||||
template <class First, class... Rest>
|
||||
static constexpr bool is_rest_empty_v = (cute::is_empty<Rest>::value && ...);
|
||||
|
||||
template <class... T>
|
||||
using ESO_t = ESO<is_first_empty_v<T...>, is_rest_empty_v<T...>, T...>;
|
||||
|
||||
// Empty First and Empty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<true, true, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO() {}
|
||||
ESO() {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO(T const&) {}
|
||||
ESO(First const&, Rest const&...) {}
|
||||
};
|
||||
|
||||
template <size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T getv(EBO<N, T, true> const&)
|
||||
{ return {}; }
|
||||
|
||||
// This is a work around approach to solve a shared memory misalign issue (https://github.com/NVIDIA/cutlass/issues/1250).
|
||||
// Will remove this work around implementation once the corresponding fix in compiler is released.
|
||||
struct dummy_EBO_base {};
|
||||
|
||||
// Specialization for types T that are not empty;
|
||||
// the "dynamic tuple leaf." Valid T here include int,
|
||||
// any other integral or floating-point type,
|
||||
// or any semiregular type for which std::is_empty_v<T> is false.
|
||||
template <size_t N, class T>
|
||||
struct EBO<N, T, false> : private dummy_EBO_base
|
||||
{
|
||||
// NonEmpty First and Empty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<false, true, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO() : t_{} {}
|
||||
ESO() : first_{} {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO(T const& t) : t_{t} {}
|
||||
ESO(First const& first, Rest const&...) : first_{first} {}
|
||||
|
||||
T t_;
|
||||
First first_;
|
||||
};
|
||||
|
||||
template <size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T const& getv(EBO<N, T, false> const& x)
|
||||
{ return x.t_; }
|
||||
|
||||
template <size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T& getv(EBO<N, T, false>& x)
|
||||
{ return x.t_; }
|
||||
|
||||
template <size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T&& getv(EBO<N, T, false>&& x)
|
||||
{ return cute::move(x.t_); }
|
||||
|
||||
template <class IdxSeq, class... T>
|
||||
struct TupleBase;
|
||||
|
||||
// Base class of cute::tuple binds each element to an index
|
||||
// by inheriting from EBO<i, t> for each (i, t) in (I..., T...).
|
||||
// The storage (for nonempty t) lives in the base classes.
|
||||
template <size_t... I, class... T>
|
||||
struct TupleBase<index_sequence<I...>, T...>
|
||||
: EBO<I,T>...
|
||||
{
|
||||
// Empty First and NonEmpty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<true, false, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TupleBase() {}
|
||||
ESO() : rest_{} {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TupleBase(T const&... t) : EBO<I,T>(t)... {}
|
||||
ESO(First const&, Rest const&... rest) : rest_{rest...} {}
|
||||
|
||||
ESO_t<Rest...> rest_;
|
||||
};
|
||||
|
||||
// NonEmpty T and NonEmpty Rest...
|
||||
template <class First, class... Rest>
|
||||
struct ESO<false, false, First, Rest...> {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO() : first_{}, rest_{} {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {}
|
||||
|
||||
First first_;
|
||||
ESO_t<Rest...> rest_;
|
||||
};
|
||||
|
||||
// Get Nth value from ESO
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>>>
|
||||
getv(ESO<F, R, T, Rest...> const&)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>> const&>
|
||||
getv(ESO<F, R, T, Rest...> const& s)
|
||||
{
|
||||
if constexpr (N == 0) {
|
||||
return static_cast<T const&>(s.first_);
|
||||
} else {
|
||||
return getv<N-1>(s.rest_);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>> &>
|
||||
getv(ESO<F, R, T, Rest...>& s)
|
||||
{
|
||||
if constexpr (N == 0) {
|
||||
return static_cast<T&>(s.first_);
|
||||
} else {
|
||||
return getv<N-1>(s.rest_);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>> &&>
|
||||
getv(ESO<F, R, T, Rest...>&& s)
|
||||
{
|
||||
if constexpr (N == 0) {
|
||||
return static_cast<T&&>(s.first_);
|
||||
} else {
|
||||
return getv<N-1>(static_cast<ESO_t<Rest...>&&>(s.rest_));
|
||||
}
|
||||
}
|
||||
|
||||
template <class X, size_t N,
|
||||
bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
findt(ESO<IsFirstEmpty, IsRestEmpty, First, Rest...> const& t) noexcept
|
||||
{
|
||||
if constexpr (cute::is_same_v<X, First>) {
|
||||
return C<N>{};
|
||||
} else
|
||||
if constexpr (sizeof...(Rest) == 0) {
|
||||
return C<N+1>{};
|
||||
} else
|
||||
if constexpr (IsRestEmpty) {
|
||||
return cute::detail::findt<X, N+1>(ESO_t<Rest...>{});
|
||||
} else {
|
||||
return cute::detail::findt<X, N+1>(t.rest_);
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Attempting to use the following commented-out alias
|
||||
// in the declaration of `struct tuple` causes MSVC 2022 build errors.
|
||||
//
|
||||
//template <class... T>
|
||||
//using TupleBase = detail::TupleBase<make_index_sequence<sizeof...(T)>, T...>;
|
||||
|
||||
// This is the actual cute::tuple class.
|
||||
// The storage (if any) lives in TupleBase's EBO base classes.
|
||||
//
|
||||
// Inheriting from the above alias TupleBase
|
||||
// causes MSVC 2022 build errors when assigning one tuple to another:
|
||||
// In summary: this is verbose as a work-around for MSVC build errors.
|
||||
template <class... T>
|
||||
struct tuple : detail::TupleBase<make_index_sequence<sizeof...(T)>, T...>
|
||||
struct tuple : detail::ESO_t<T...>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple() {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple(T const&... t) : detail::TupleBase<make_index_sequence<sizeof...(T)>, T...>(t...) {}
|
||||
tuple(T const&... t) : detail::ESO_t<T...>(t...) {}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct tuple<>
|
||||
{};
|
||||
|
||||
//
|
||||
// get for cute::tuple (just like std::get for std::tuple)
|
||||
//
|
||||
struct tuple<> {};
|
||||
|
||||
// Returns the element in the ith position of the tuple
|
||||
template <size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
@@ -224,25 +238,19 @@ decltype(auto)
|
||||
get(tuple<T...>&& t) noexcept
|
||||
{
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(static_cast<tuple<T...>&&>(t));
|
||||
return detail::getv<I>(static_cast<detail::ESO_t<T...>&&>(t));
|
||||
}
|
||||
|
||||
//
|
||||
// find a type X within a cute::tuple
|
||||
// Requires X to be unique in tuple
|
||||
// Returns a static integer
|
||||
//
|
||||
|
||||
// Returns the position of type X (as a static integer) in the tuple
|
||||
// type's argument list. X must be unique in the argument list.
|
||||
template <class X, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find(tuple<T...> const& t) noexcept
|
||||
{
|
||||
return detail::findt<X>(t);
|
||||
return detail::findt<X, 0>(t);
|
||||
}
|
||||
|
||||
#endif // CUTLASS_USE_PACKED_TUPLE
|
||||
|
||||
//
|
||||
// Custom is_tuple trait simply checks the existence of tuple_size
|
||||
// and assumes std::get<I>(.), std::tuple_element<I,.>
|
||||
@@ -258,7 +266,7 @@ auto has_tuple_size(...) -> false_type;
|
||||
template <class T>
|
||||
struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {};
|
||||
|
||||
template<typename T>
|
||||
template <class T>
|
||||
constexpr bool is_tuple_v = cute::is_tuple<T>::value;
|
||||
|
||||
//
|
||||
@@ -679,8 +687,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t)
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
#if ! defined(CUTLASS_USE_PACKED_TUPLE)
|
||||
|
||||
namespace CUTE_STL_NAMESPACE
|
||||
{
|
||||
|
||||
@@ -694,22 +700,8 @@ struct tuple_element<I, cute::tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace CUTE_STL_NAMESPACE
|
||||
|
||||
//
|
||||
// std compatibility
|
||||
//
|
||||
|
||||
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
namespace std
|
||||
{
|
||||
@@ -732,17 +724,5 @@ struct tuple_element<I, cute::tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
|
||||
#endif // CUTLASS_USE_PACKED_TUPLE
|
||||
|
||||
@@ -73,17 +73,6 @@ struct tuple_element<I, cute::type_list<T...>>
|
||||
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
|
||||
};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::type_list<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::type_list<T...>>
|
||||
{
|
||||
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
|
||||
};
|
||||
|
||||
} // end namespace std
|
||||
|
||||
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
@@ -109,16 +98,5 @@ struct tuple_element<I, cute::type_list<T...>>
|
||||
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
|
||||
};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::type_list<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::type_list<T...>>
|
||||
{
|
||||
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
|
||||
};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
|
||||
@@ -330,7 +330,7 @@ ceil_div(IntTupleA const& a, IntTupleB const& b)
|
||||
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1
|
||||
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); });
|
||||
} else { // tuple int
|
||||
auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
|
||||
auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
|
||||
[] (auto const& init, auto const& ai) {
|
||||
return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai));
|
||||
});
|
||||
@@ -390,7 +390,7 @@ shape_div(IntTupleA const& a, IntTupleB const& b)
|
||||
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
|
||||
return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); });
|
||||
} else { // tuple int
|
||||
auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
|
||||
auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b),
|
||||
[] (auto const& init, auto const& ai) {
|
||||
return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai));
|
||||
});
|
||||
|
||||
@@ -1044,7 +1044,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
|
||||
auto result_shape_0 = take<0,R-1>(lhs_shape);
|
||||
|
||||
// Mod out the rhs_shape from the lhs_shape
|
||||
auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape),
|
||||
auto [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape),
|
||||
[] (auto const& init, auto const& si) {
|
||||
return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si)));
|
||||
});
|
||||
@@ -1058,7 +1058,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
|
||||
auto result_stride_0 = take<0,R-1>(lhs_stride);
|
||||
|
||||
// Divide out the rhs_stride from the lhs_shape
|
||||
auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride),
|
||||
auto [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride),
|
||||
[] (auto const& init, auto const& di) {
|
||||
return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di));
|
||||
});
|
||||
@@ -1067,7 +1067,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
|
||||
auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1));
|
||||
|
||||
// Mod out the rhs_shape from the lhs_shape
|
||||
auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape),
|
||||
auto [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape),
|
||||
[] (auto const& init, auto const& si) {
|
||||
return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si)));
|
||||
});
|
||||
|
||||
@@ -508,16 +508,6 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace CUTE_STL_NAMESPACE
|
||||
|
||||
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
@@ -542,15 +532,5 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
|
||||
@@ -84,7 +84,6 @@ using CUTE_STL_NAMESPACE::uint16_t;
|
||||
using CUTE_STL_NAMESPACE::uint32_t;
|
||||
using CUTE_STL_NAMESPACE::uint64_t;
|
||||
using cutlass::uint128_t;
|
||||
|
||||
template <int N> struct uint_bit;
|
||||
template <> struct uint_bit< 1> { using type = uint1_t; };
|
||||
template <> struct uint_bit< 2> { using type = uint2_t; };
|
||||
@@ -95,7 +94,6 @@ template <> struct uint_bit< 16> { using type = uint16_t; };
|
||||
template <> struct uint_bit< 32> { using type = uint32_t; };
|
||||
template <> struct uint_bit< 64> { using type = uint64_t; };
|
||||
template <> struct uint_bit<128> { using type = cutlass::uint128_t; };
|
||||
|
||||
template <int N>
|
||||
using uint_bit_t = typename uint_bit<N>::type;
|
||||
|
||||
|
||||
@@ -235,7 +235,7 @@ struct Tensor
|
||||
decltype(auto)
|
||||
operator()(Coord const& coord) {
|
||||
if constexpr (has_underscore<Coord>::value) {
|
||||
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
|
||||
auto [sliced_layout,offset] = slice_and_offset(coord, layout());
|
||||
return make_tensor(data() + offset, sliced_layout);
|
||||
} else {
|
||||
return data()[layout()(coord)];
|
||||
@@ -249,7 +249,7 @@ struct Tensor
|
||||
decltype(auto)
|
||||
operator()(Coord const& coord) const {
|
||||
if constexpr (has_underscore<Coord>::value) {
|
||||
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
|
||||
auto [sliced_layout,offset] = slice_and_offset(coord, layout());
|
||||
return make_tensor(data() + offset, sliced_layout);
|
||||
} else {
|
||||
return data()[layout()(coord)];
|
||||
|
||||
@@ -102,6 +102,7 @@
|
||||
#if (!defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM100_ALL))
|
||||
#define CUTLASS_ARCH_MMA_SM100A_ENABLED 1
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
@@ -38,10 +38,14 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#ifndef CUDA_CTA_RECONFIG_ACTIVATED
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12 && \
|
||||
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
||||
#if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \
|
||||
(__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \
|
||||
|| (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \
|
||||
)
|
||||
#define CUDA_CTA_RECONFIG_ACTIVATED 1
|
||||
#endif
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@@ -106,7 +106,6 @@ struct CollectiveConv<
|
||||
|
||||
using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;
|
||||
|
||||
// TODO: move pipeline mode tiling into the collective setup phase instead
|
||||
static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
|
||||
static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");
|
||||
static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");
|
||||
|
||||
@@ -255,23 +255,27 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t activation_size() const {
|
||||
|
||||
return (N * H * W * C);
|
||||
return static_cast<int64_t>(N) * static_cast<int64_t>(H) *
|
||||
static_cast<int64_t>(W) * static_cast<int64_t>(C);
|
||||
}
|
||||
|
||||
/// Returns filter size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t filter_size() const {
|
||||
|
||||
return (K * R * S * C / groups);
|
||||
return static_cast<int64_t>(K) * static_cast<int64_t>(R) *
|
||||
static_cast<int64_t>(S) * static_cast<int64_t>(C) /
|
||||
static_cast<int64_t>(groups);
|
||||
}
|
||||
|
||||
/// Returns output size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t output_size() const {
|
||||
|
||||
return (N * P * Q * K);
|
||||
return static_cast<int64_t>(N) * static_cast<int64_t>(P) *
|
||||
static_cast<int64_t>(Q) * static_cast<int64_t>(K);
|
||||
}
|
||||
|
||||
|
||||
/// Returns padding as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord padding() const {
|
||||
|
||||
@@ -285,21 +285,27 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t activation_size() const {
|
||||
|
||||
return (N * D * H * W * C);
|
||||
return static_cast<int64_t>(N) * static_cast<int64_t>(D) *
|
||||
static_cast<int64_t>(H) * static_cast<int64_t>(W) *
|
||||
static_cast<int64_t>(C);
|
||||
}
|
||||
|
||||
/// Returns filter size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t filter_size() const {
|
||||
|
||||
return (K * T * R * S * C);
|
||||
return static_cast<int64_t>(K) * static_cast<int64_t>(T) *
|
||||
static_cast<int64_t>(R) * static_cast<int64_t>(S) *
|
||||
static_cast<int64_t>(C);
|
||||
}
|
||||
|
||||
/// Returns output size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t output_size() const {
|
||||
|
||||
return (N * Z * P * Q * K);
|
||||
return static_cast<int64_t>(N) * static_cast<int64_t>(Z) *
|
||||
static_cast<int64_t>(P) * static_cast<int64_t>(Q) *
|
||||
static_cast<int64_t>(K);
|
||||
}
|
||||
|
||||
/// Returns padding as Coord3D
|
||||
|
||||
@@ -114,6 +114,33 @@ public:
|
||||
return status;
|
||||
}
|
||||
|
||||
// Check that tensor sizes don't exceed maximum supported size
|
||||
if (kConvolutionalOperator == conv::Operator::kFprop) {
|
||||
if (args.problem_size.activation_size() * sizeof(ElementA) >=
|
||||
(1ull << 31) ||
|
||||
args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
|
||||
args.problem_size.output_size() * sizeof(ElementC) >= (1ull << 31)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
else if (kConvolutionalOperator == conv::Operator::kDgrad ||
|
||||
kConvolutionalOperator == conv::Operator::kDeconv) {
|
||||
if (args.problem_size.activation_size() * sizeof(ElementC) >=
|
||||
(1ull << 31) ||
|
||||
args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
|
||||
args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
else if (kConvolutionalOperator == conv::Operator::kWgrad) {
|
||||
if (args.problem_size.activation_size() * sizeof(ElementB) >=
|
||||
(1ull << 31) ||
|
||||
args.problem_size.filter_size() * sizeof(ElementC) >= (1ull << 31) ||
|
||||
args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
|
||||
// check group conv constraint
|
||||
if (args.problem_size.groups != 1) {
|
||||
if (kGroupMode == conv::GroupMode::kNone) {
|
||||
|
||||
@@ -104,7 +104,7 @@ namespace cutlass {
|
||||
|
||||
#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 13)
|
||||
#if (__CUDACC_VER_MAJOR__ > 12)
|
||||
|
||||
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
|
||||
template <typename... Args> \
|
||||
@@ -142,7 +142,7 @@ namespace cutlass {
|
||||
return reinterpret_cast<PFN_##func>(pfn)(args...); \
|
||||
}
|
||||
|
||||
#endif // (__CUDACC_VERSION__ >= 12.5)
|
||||
#endif // (__CUDACC_VER_MAJOR__ > 12)
|
||||
|
||||
#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ struct LayoutAwareConvertImpl {
|
||||
auto&& src_vm = cute::recast<SrcArray>(src);
|
||||
auto&& dst_vm = cute::recast<DstArray>(dst);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i <src.size(); ++i) {
|
||||
for (int i = 0; i < src_vm.size(); ++i) {
|
||||
dst_vm(i) = Converter::convert(src_vm(i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@
|
||||
#if ! defined(_MSC_VER)
|
||||
#define CUTLASS_LAMBDA_FUNC_INLINE __attribute__((always_inline))
|
||||
#else
|
||||
#define CUTLASS_LAMBDA_FUNC_INLINE
|
||||
#define CUTLASS_LAMBDA_FUNC_INLINE [[msvc::forceinline]]
|
||||
#endif
|
||||
|
||||
#define CUTLASS_HOST __host__
|
||||
|
||||
@@ -98,142 +98,6 @@ sm100_get_epilogue_smem_swizzle_layout_atom() {
|
||||
}
|
||||
}
|
||||
|
||||
// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
|
||||
template <
|
||||
class OpClass,
|
||||
class CtaTileShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class TmemWarpShape_MN,
|
||||
class ElementC,
|
||||
class StrideC,
|
||||
class ElementD,
|
||||
class StrideD,
|
||||
class FusionOp
|
||||
>
|
||||
constexpr auto
|
||||
sm100_compute_tile_shape_or_override() {
|
||||
using namespace cute;
|
||||
|
||||
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto> &&
|
||||
cute::is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
|
||||
size<1>(CtaTileShape_MNK{}) == 256) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int DpFull = 32;
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
// Note:
|
||||
// Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile.
|
||||
// This is a general workable epi_tile_N which does not promise best perf.
|
||||
return make_tile(Int<M>{}, Int<128>{});
|
||||
}
|
||||
else if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
|
||||
constexpr bool DisableSource = is_void_v<ElementC>;
|
||||
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
|
||||
|
||||
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
|
||||
// Epilogues w/o residual load are less sensitive to smem allocation
|
||||
// Target a fixed amount of compute per epilogue iteration
|
||||
if (DisableSource) {
|
||||
if (MaxBits == 4) {
|
||||
// Make epilogue tile larger to reduce the epilogue iterations.
|
||||
// 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
|
||||
constexpr int ComputeElts = 8192;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
constexpr int ComputeElts = 4096;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
// Epilogues w/ residual load are more sensitive to smem allocation
|
||||
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
|
||||
else {
|
||||
if (MaxBits == 32) {
|
||||
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
|
||||
}
|
||||
// Per-column scaling is high register pressure, reduce tile to prevent spills
|
||||
else if (FusionOp::IsPerColScaleSupported) {
|
||||
return 32;
|
||||
}
|
||||
else if (MaxBits == 16) {
|
||||
return (CtaN <= 128) ? 32 : 64;
|
||||
}
|
||||
else {
|
||||
return 64;
|
||||
}
|
||||
}
|
||||
}();
|
||||
constexpr int N_min_C = (DisableSource || detail::is_m_major<StrideC>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementC> * WarpN;
|
||||
constexpr int N_min_D = (detail::is_m_major<StrideD>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementD> * WarpN;
|
||||
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
|
||||
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
|
||||
|
||||
// stride by tmem warp layout and return a by-mode tiler
|
||||
auto tile_m = Layout<Int<M>>{};
|
||||
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
|
||||
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
|
||||
|
||||
return make_tile(tile_m, coalesce(tile_n));
|
||||
}
|
||||
else if constexpr (cute::is_tuple<EpilogueTileType>::value) {
|
||||
EpilogueTileType epi_tile;
|
||||
constexpr int M = size<0>(shape(epi_tile));
|
||||
constexpr int N = size<1>(shape(epi_tile));
|
||||
|
||||
static_assert(!is_layout<EpilogueTileType>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
|
||||
static_assert(TmemWarpShape_MN{} == Shape<_2,_2>{} && (M == 32 || M == 64) ||
|
||||
TmemWarpShape_MN{} == Shape<_4,_1>{} && (M == 64 || M == 128), "Unsupported tile shape");
|
||||
static_assert(N % 8 == 0, "Unsupported tile shape");
|
||||
|
||||
return epi_tile;
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<EpilogueTileType>, "Invalid type for EpilogueTileType.");
|
||||
}
|
||||
}
|
||||
|
||||
template <class EpilogueScheduleType>
|
||||
static constexpr bool IsPtrArrayDispatchPolicy =
|
||||
cute::is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized1Sm> ||
|
||||
cute::is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized2Sm>;
|
||||
|
||||
|
||||
template <
|
||||
class CtaTileShape_MNK,
|
||||
class EpilogueTile_MN,
|
||||
class ElementC,
|
||||
class ElementD,
|
||||
class Schedule
|
||||
>
|
||||
constexpr auto
|
||||
sm100_get_tma_dispatch_policy() {
|
||||
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{})));
|
||||
constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
|
||||
constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup;
|
||||
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
|
||||
constexpr bool ReuseSmem = sizeof_bits_v<ElementC> > 8;
|
||||
constexpr bool DelayTmaStore = false;
|
||||
constexpr int StagesD = cute::min(EpiTiles, 2);
|
||||
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
|
||||
: cute::min(EpiTiles, 4);
|
||||
|
||||
if constexpr (detail::IsPtrArrayDispatchPolicy<Schedule>) {
|
||||
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Returns the TMEM_LOAD copy op to be used for the epilogue
|
||||
* Returned TMEM_LOAD op is such that the thread-value ownership matches the widest available
|
||||
@@ -344,10 +208,10 @@ sm100_get_tmem_load_op() {
|
||||
// For complex TF32 kernels
|
||||
else if constexpr (sizeof_bits_v<ElementAccumulator> == 64 && sizeof_bits_v<ElementD> == 64) {
|
||||
if constexpr (num_dp == 16) {
|
||||
return TMEM::op_repeater<SM100_TMEM_LOAD_16dp256b1x, num_col_bits>();
|
||||
return TMEM::op_repeater<SM100_TMEM_LOAD_16dp256b1x, num_col_bits/2>();
|
||||
}
|
||||
else {
|
||||
return TMEM::op_repeater<SM100_TMEM_LOAD_32dp32b1x, num_col_bits>();
|
||||
return TMEM::op_repeater<SM100_TMEM_LOAD_32dp32b1x, num_col_bits/2>();
|
||||
}
|
||||
}
|
||||
// For narrow precision output
|
||||
@@ -376,7 +240,6 @@ sm100_get_smem_store_op() {
|
||||
static_assert(is_m_major || is_n_major, "Unsupported gmem layout");
|
||||
|
||||
// Check for TMEM_LOAD layouts that match the thread-value ownership pattern of stmatrix
|
||||
// TODO: check copy vectorization instead!
|
||||
constexpr bool use_stmatrix_m8n8_4x =
|
||||
(sizeof_bits_v<ElementAccumulator> == 32 && sizeof_bits_v<ElementD> == 32 && is_n_major &&
|
||||
( cute::is_same_v<AccLoadOp, SM100_TMEM_LOAD_16dp128b2x> ||
|
||||
@@ -451,22 +314,7 @@ sm100_get_smem_store_op() {
|
||||
}
|
||||
}
|
||||
|
||||
template <class GmemStrideTypeD, class ElementD>
|
||||
constexpr auto
|
||||
sm100_get_register_transform_op() {
|
||||
using namespace cute;
|
||||
|
||||
[[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{});
|
||||
[[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{});
|
||||
static_assert(is_m_major || is_n_major, "Unsupported gmem layout");
|
||||
|
||||
if constexpr (sizeof_bits_v<ElementD> == 4 && is_m_major) {
|
||||
return SM50_Shuffle_U32_2x2Trans_XOR1{};
|
||||
}
|
||||
else {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<128>{};
|
||||
}
|
||||
}
|
||||
|
||||
// Selects the largest vectorized smem load atom available
|
||||
// subject to constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership
|
||||
@@ -503,30 +351,6 @@ sm100_get_smem_load_op() {
|
||||
}
|
||||
}
|
||||
|
||||
template <class Schedule, class LayoutTag>
|
||||
constexpr auto
|
||||
sm100_get_gmem_load_op() {
|
||||
if constexpr (detail::is_im2col_mode<LayoutTag>) {
|
||||
return SM90_TMA_LOAD_IM2COL{};
|
||||
}
|
||||
else {
|
||||
|
||||
return SM90_TMA_LOAD{};
|
||||
}
|
||||
}
|
||||
|
||||
template <class Schedule, class LayoutTag>
|
||||
constexpr auto
|
||||
sm100_get_gmem_store_op() {
|
||||
if constexpr (detail::is_im2col_mode<LayoutTag>) {
|
||||
return SM90_TMA_STORE_IM2COL{};
|
||||
}
|
||||
else {
|
||||
|
||||
return SM90_TMA_STORE{};
|
||||
}
|
||||
}
|
||||
|
||||
// aux fusion callbacks builder for sm100 tma epilogue
|
||||
template <
|
||||
int StagesC,
|
||||
@@ -622,9 +446,9 @@ struct CallbacksBuilder<
|
||||
// the fusion operation performed and the dispatch policy to use.
|
||||
template <
|
||||
class OpClass,
|
||||
class CtaTileShape_MNK,
|
||||
class MmaTileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class TmemWarpShape_MN,
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
class ElementC_,
|
||||
@@ -637,62 +461,237 @@ template <
|
||||
class FusionOpOrCallbacks
|
||||
>
|
||||
struct Sm100TmaBuilderImpl {
|
||||
private:
|
||||
static constexpr bool Is1SmMma = is_base_of_v<TmaWarpSpecialized1Sm, Schedule>;
|
||||
static constexpr bool Is2SmMma = is_base_of_v<TmaWarpSpecialized2Sm, Schedule>;
|
||||
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
|
||||
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
|
||||
|
||||
// Passing void C disables source load + smem allocation
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
|
||||
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
using CopyOpS2G = decltype(detail::sm100_get_gmem_store_op<Schedule,GmemLayoutTagD>());
|
||||
using CopyOpG2S = decltype(detail::sm100_get_gmem_load_op<Schedule,GmemLayoutTagC>());
|
||||
|
||||
using FusionOp = conditional_t<is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>,
|
||||
FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
|
||||
|
||||
using EpilogueTile_MN = decltype(detail::sm100_compute_tile_shape_or_override<
|
||||
OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN,
|
||||
ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp>());
|
||||
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{})));
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
|
||||
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
|
||||
static constexpr bool DisableSource = cute::is_void_v<ElementC_>;
|
||||
using ElementC = cute::conditional_t<DisableSource,ElementD,ElementC_>; // prevents void ref breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<DisableSource,GmemLayoutTagD,GmemLayoutTagC_>;
|
||||
|
||||
using InternalSmemElementC = typename cutlass::detail::get_unpacked_element_type<ElementC>::type;
|
||||
using InternalSmemElementD = typename cutlass::detail::get_unpacked_element_type<ElementD>::type;
|
||||
|
||||
using DispatchPolicy = decltype(detail::sm100_get_tma_dispatch_policy<
|
||||
CtaTileShape_MNK, EpilogueTile_MN, ElementC_, ElementD, Schedule>());
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
// TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
|
||||
// instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
|
||||
using FusionCallbacks =
|
||||
typename CallbacksBuilder<
|
||||
DispatchPolicy,
|
||||
FusionOpOrCallbacks,
|
||||
CtaTileShape_MNK,
|
||||
EpilogueTile_MN,
|
||||
ElementAccumulator,
|
||||
AccLoadOp
|
||||
>::Callbacks;
|
||||
static constexpr bool IsTaggedFusionOp = is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>;
|
||||
using FusionOp = conditional_t<IsTaggedFusionOp, FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
|
||||
|
||||
static constexpr auto
|
||||
cta_tile_shape() {
|
||||
if constexpr (Is2SmMma) { // 2x1 threadblock shape
|
||||
auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{};
|
||||
auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode
|
||||
return make_shape(cta_tile_m, mma_tile_n, mma_tile_k);
|
||||
}
|
||||
else { // 1x1 threadblock shape
|
||||
return MmaTileShape_MNK{};
|
||||
}
|
||||
}
|
||||
using CtaTileShape_MNK = decltype(cta_tile_shape());
|
||||
|
||||
static constexpr auto
|
||||
tmem_warps() {
|
||||
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
|
||||
return Shape<_2,_2>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_4,_1>{};
|
||||
}
|
||||
}
|
||||
using TmemWarpShape_MN = decltype(tmem_warps());
|
||||
|
||||
// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one.
|
||||
static constexpr auto
|
||||
epilogue_tile() {
|
||||
using namespace cute;
|
||||
|
||||
if constexpr (is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
|
||||
is_same_v<EpilogueTileType, EpilogueTileAuto> &&
|
||||
size<1>(CtaTileShape_MNK{}) == 256) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int DpFull = 32;
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
// Note:
|
||||
// Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile.
|
||||
// This is a general workable epi_tile_N which does not promise best perf.
|
||||
return make_tile(Int<M>{}, Int<128>{});
|
||||
}
|
||||
else if constexpr (is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
|
||||
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
|
||||
|
||||
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
|
||||
// Epilogues w/o residual load are less sensitive to smem allocation
|
||||
// Target a fixed amount of compute per epilogue iteration
|
||||
if (DisableSource) {
|
||||
if (MaxBits == 4) {
|
||||
// Make epilogue tile larger to reduce the epilogue iterations.
|
||||
// 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
|
||||
constexpr int ComputeElts = 8192;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
constexpr int ComputeElts = 4096;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
// Epilogues w/ residual load are more sensitive to smem allocation
|
||||
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
|
||||
else {
|
||||
if (MaxBits == 32) {
|
||||
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
|
||||
}
|
||||
// Per-column scaling is high register pressure, reduce tile to prevent spills
|
||||
else if (FusionOp::IsPerColScaleSupported) {
|
||||
return 32;
|
||||
}
|
||||
else if (MaxBits == 16) {
|
||||
return (CtaN <= 128) ? 32 : 64;
|
||||
}
|
||||
else {
|
||||
return 64;
|
||||
}
|
||||
}
|
||||
}();
|
||||
constexpr int N_min_C = (DisableSource || detail::is_m_major<GmemStrideTypeC>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementC> * WarpN;
|
||||
constexpr int N_min_D = (detail::is_m_major<GmemStrideTypeD>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementD> * WarpN;
|
||||
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
|
||||
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
|
||||
|
||||
// stride by tmem warp layout and return a by-mode tiler
|
||||
auto tile_m = Layout<Int<M>>{};
|
||||
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
|
||||
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
|
||||
|
||||
return make_tile(tile_m, coalesce(tile_n));
|
||||
}
|
||||
else {
|
||||
static_assert(cute::is_tuple<EpilogueTileType>::value && not is_layout<EpilogueTileType>::value,
|
||||
"EpilogueTile must be a cute::Tile or cute::Shape");
|
||||
|
||||
EpilogueTileType epi_tile;
|
||||
constexpr int M = size<0>(shape(epi_tile));
|
||||
constexpr int N = size<1>(shape(epi_tile));
|
||||
static_assert(N % 8 == 0, "Unsupported tile shape");
|
||||
|
||||
return epi_tile;
|
||||
}
|
||||
}
|
||||
using EpilogueTile_MN = decltype(epilogue_tile());
|
||||
|
||||
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{})));
|
||||
static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
|
||||
static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup;
|
||||
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
|
||||
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
|
||||
|
||||
static constexpr auto
|
||||
dispatch_policy() {
|
||||
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
|
||||
constexpr bool ReuseSmem = sizeof_bits_v<ElementC_> > 8;
|
||||
// TMA store delay performs worse with residual loads
|
||||
constexpr bool DelayTmaStore = is_void_v<ElementC_>;
|
||||
|
||||
constexpr int StagesD = cute::min(EpiTiles, 2);
|
||||
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
|
||||
: cute::min(EpiTiles, 4);
|
||||
|
||||
if constexpr (is_same_v<Schedule, PtrArrayTmaWarpSpecialized1Sm> ||
|
||||
is_same_v<Schedule, PtrArrayTmaWarpSpecialized2Sm>) {
|
||||
constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs
|
||||
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore_>{};
|
||||
}
|
||||
else {
|
||||
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto
|
||||
fusion_callbacks() {
|
||||
{
|
||||
return typename CallbacksBuilder<
|
||||
decltype(dispatch_policy()),
|
||||
FusionOpOrCallbacks,
|
||||
CtaTileShape_MNK,
|
||||
EpilogueTile_MN,
|
||||
ElementAccumulator,
|
||||
AccLoadOp
|
||||
>::Callbacks({},{});
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto
|
||||
gmem_load_op() {
|
||||
if constexpr (detail::is_im2col_mode<GmemLayoutTagC>) {
|
||||
return SM90_TMA_LOAD_IM2COL{};
|
||||
}
|
||||
else {
|
||||
return SM90_TMA_LOAD{};
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto
|
||||
gmem_store_op() {
|
||||
if constexpr (detail::is_im2col_mode<GmemLayoutTagD>) {
|
||||
return SM90_TMA_STORE_IM2COL{};
|
||||
}
|
||||
else {
|
||||
return SM90_TMA_STORE{};
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto
|
||||
register_shuffle_op() {
|
||||
using namespace cute;
|
||||
|
||||
[[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{});
|
||||
[[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{});
|
||||
static_assert(is_m_major || is_n_major, "Unsupported gmem layout");
|
||||
|
||||
if constexpr (sizeof_bits_v<InternalSmemElementD> == 4 && is_m_major) {
|
||||
return SM50_Shuffle_U32_2x2Trans_XOR1{};
|
||||
}
|
||||
else {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<128>{};
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
using CollectiveOp =
|
||||
cutlass::epilogue::collective::CollectiveEpilogue<
|
||||
DispatchPolicy,
|
||||
decltype(dispatch_policy()),
|
||||
CtaTileShape_MNK,
|
||||
EpilogueTile_MN,
|
||||
ElementC_, // Need to pass void through to expose via GemmUniversal
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
decltype(fusion_callbacks()),
|
||||
AccLoadOp,
|
||||
CopyOpG2S,
|
||||
decltype(gmem_load_op()),
|
||||
decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, InternalSmemElementC, EpilogueTile_MN>()),
|
||||
decltype(detail::sm100_get_smem_load_op<GmemStrideTypeC, InternalSmemElementC, ElementAccumulator, AccLoadOp>()),
|
||||
CopyOpS2G,
|
||||
decltype(gmem_store_op()),
|
||||
decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, InternalSmemElementD, EpilogueTile_MN>()),
|
||||
decltype(detail::sm100_get_smem_store_op<GmemStrideTypeD, InternalSmemElementD, ElementAccumulator, AccLoadOp>()),
|
||||
decltype(detail::sm100_get_register_transform_op<GmemStrideTypeD, InternalSmemElementD>())
|
||||
decltype(register_shuffle_op())
|
||||
>;
|
||||
};
|
||||
|
||||
@@ -702,7 +701,8 @@ struct Sm100TmaBuilderImpl {
|
||||
|
||||
// No smem builder
|
||||
template <
|
||||
class CtaTileShape_MNK,
|
||||
class OpClass,
|
||||
class MmaTileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class ElementAccumulator,
|
||||
@@ -718,8 +718,8 @@ template <
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassTensorOp,
|
||||
CtaTileShape_MNK,
|
||||
OpClass,
|
||||
MmaTileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
@@ -732,11 +732,16 @@ struct CollectiveBuilder<
|
||||
AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
FusionOpOrCallbacks,
|
||||
cute::enable_if_t<cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized> ||
|
||||
cute::is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized> >> {
|
||||
cute::enable_if_t<is_base_of_v<NoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||
is_base_of_v<NoSmemWarpSpecialized2Sm, EpilogueScheduleType> >
|
||||
> {
|
||||
private:
|
||||
static_assert(cute::sizeof_bits_v<ElementD> != 6, "Output element requires TMA");
|
||||
|
||||
static_assert(cute::is_same_v<EpilogueTileType, EpilogueTileAuto>, "Epilogue subtiling requires smem");
|
||||
static_assert(cute::sizeof_bits_v<ElementD> != 4 and cute::sizeof_bits_v<ElementD> != 6, "Output element requires smem");
|
||||
static constexpr bool Is1SmMma = is_base_of_v<NoSmemWarpSpecialized1Sm, EpilogueScheduleType>;
|
||||
static constexpr bool Is2SmMma = is_base_of_v<NoSmemWarpSpecialized2Sm, EpilogueScheduleType>;
|
||||
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
|
||||
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
|
||||
|
||||
static constexpr bool DisableSource = cute::is_void_v<ElementC_>;
|
||||
using ElementC = cute::conditional_t<DisableSource, ElementD, ElementC_>; // prevents void ref breakages
|
||||
@@ -744,173 +749,110 @@ struct CollectiveBuilder<
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
using FusionOp = conditional_t<is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>,
|
||||
FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
|
||||
static constexpr bool IsTaggedFusionOp = is_base_of_v<epilogue::fusion::FusionOperation, FusionOpOrCallbacks>;
|
||||
using FusionOp = conditional_t<IsTaggedFusionOp, FusionOpOrCallbacks, epilogue::fusion::FusionOperation>;
|
||||
|
||||
// use a 4x2 division to select tmem load shape in order to maintain compatability with both (4,1) and (2,2) layouts
|
||||
using EpilogueTile = decltype(take<0,2>(CtaTileShape_MNK{}));
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_2>{}));
|
||||
static constexpr auto
|
||||
cta_tile_shape() {
|
||||
if constexpr (Is2SmMma) { // 2x1 threadblock shape
|
||||
auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{};
|
||||
auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode
|
||||
return make_shape(cta_tile_m, mma_tile_n, mma_tile_k);
|
||||
}
|
||||
else { // 1x1 threadblock shape
|
||||
return MmaTileShape_MNK{};
|
||||
}
|
||||
}
|
||||
using CtaTileShape_MNK = decltype(cta_tile_shape());
|
||||
|
||||
static constexpr auto
|
||||
tmem_warps() {
|
||||
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
|
||||
return Shape<_2,_2>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_4,_1>{};
|
||||
}
|
||||
}
|
||||
using TmemWarpShape_MN = decltype(tmem_warps());
|
||||
|
||||
static constexpr auto
|
||||
epilogue_tile() {
|
||||
using namespace cute;
|
||||
if constexpr (not is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
static_assert(is_tuple_v<EpilogueTileType>, "Shape or Tile");
|
||||
return EpilogueTileType{};
|
||||
}
|
||||
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp>) { // perf specialized case
|
||||
constexpr int EpiM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int EpiN = cute::min(_64{}, size<1>(CtaTileShape_MNK{}));
|
||||
return Shape<Int<EpiM>, Int<EpiN>>{};
|
||||
}
|
||||
else {
|
||||
return take<0,2>(CtaTileShape_MNK{});
|
||||
}
|
||||
}
|
||||
using EpilogueTile = decltype(epilogue_tile());
|
||||
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{}));
|
||||
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
|
||||
static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup;
|
||||
|
||||
using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized;
|
||||
static constexpr auto
|
||||
dispatch_policy() {
|
||||
if constexpr (is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized1Sm> ||
|
||||
is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized2Sm>) {
|
||||
return Sm100PtrArrayNoSmemWarpSpecialized{};
|
||||
}
|
||||
else {
|
||||
return Sm100NoSmemWarpSpecialized{};
|
||||
}
|
||||
}
|
||||
using DispatchPolicy = decltype(dispatch_policy());
|
||||
|
||||
using AlignmentCType = Int<AlignmentC>;
|
||||
using AlignmentDType = Int<AlignmentD>;
|
||||
static constexpr auto
|
||||
fusion_callbacks() {
|
||||
constexpr thread::ScaleType::Kind ScaleType =
|
||||
DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||
if constexpr (IsDefaultFusionOp<FusionOp>::value && not is_same_v<OpClass,arch::OpClassBlockScaledTensorOp>) {
|
||||
// Legacy codepath using thread::LinearCombination, do not expect this to be stable
|
||||
return thread::LinearCombination<
|
||||
ElementD, 1, ElementAccumulator, ElementCompute, ScaleType, FusionOp::RoundStyle, ElementC>({});
|
||||
}
|
||||
else {
|
||||
return typename detail::CallbacksBuilder<
|
||||
DispatchPolicy,
|
||||
FusionOpOrCallbacks,
|
||||
CtaTileShape_MNK,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
AccLoadOp
|
||||
>::Callbacks({},{});
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
|
||||
static constexpr thread::ScaleType::Kind ScaleType = DisableSource ?
|
||||
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||
|
||||
using FusionCallbacks = cute::conditional_t<
|
||||
IsDefaultFusionOp<FusionOp>::value,
|
||||
// Legacy codepath using thread::LinearCombination, do not expect this to be stable
|
||||
thread::LinearCombination<
|
||||
ElementD, 1, ElementAccumulator, ElementCompute,
|
||||
ScaleType, RoundStyle, ElementC>
|
||||
,
|
||||
typename detail::CallbacksBuilder<
|
||||
public:
|
||||
using CollectiveOp =
|
||||
cutlass::epilogue::collective::CollectiveEpilogue<
|
||||
DispatchPolicy,
|
||||
FusionOpOrCallbacks,
|
||||
CtaTileShape_MNK,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
AccLoadOp
|
||||
>::Callbacks
|
||||
>;
|
||||
|
||||
using CollectiveOp = cute::conditional_t<
|
||||
cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized>,
|
||||
cutlass::epilogue::collective::CollectiveEpilogue<
|
||||
cutlass::epilogue::Sm100NoSmemWarpSpecialized,
|
||||
EpilogueTile,
|
||||
ElementC_,
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
decltype(fusion_callbacks()),
|
||||
AccLoadOp,
|
||||
AlignmentCType,
|
||||
AlignmentDType
|
||||
>,
|
||||
cutlass::epilogue::collective::CollectiveEpilogue<
|
||||
cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized,
|
||||
EpilogueTile,
|
||||
ElementC_,
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
AccLoadOp
|
||||
>
|
||||
>;
|
||||
};
|
||||
|
||||
// No smem builder for OpClassBlockScaledTensorOp
|
||||
template <
|
||||
class CtaTileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
class ElementC_,
|
||||
class GmemLayoutTagC_,
|
||||
int AlignmentC,
|
||||
class ElementD,
|
||||
class GmemLayoutTagD,
|
||||
int AlignmentD,
|
||||
class EpilogueScheduleType,
|
||||
class FusionOp
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassBlockScaledTensorOp,
|
||||
CtaTileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC_,
|
||||
GmemLayoutTagC_,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
FusionOp,
|
||||
cute::enable_if_t<cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized> ||
|
||||
cute::is_same_v<EpilogueScheduleType, PtrArrayNoSmemWarpSpecialized> >> {
|
||||
|
||||
static_assert(cute::sizeof_bits_v<ElementD> != 6, "Output element requires smem");
|
||||
|
||||
static constexpr bool DisableSource = cute::is_void_v<ElementC_>;
|
||||
using ElementC = cute::conditional_t<DisableSource, ElementD, ElementC_>; // prevents void ref breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<DisableSource, GmemLayoutTagD, GmemLayoutTagC_>;
|
||||
static constexpr thread::ScaleType::Kind ScaleType = DisableSource ?
|
||||
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
static_assert(cute::is_tuple<EpilogueTileType>::value || cute::is_same_v<EpilogueTileType, EpilogueTileAuto>);
|
||||
using EpilogueTile = cute::conditional_t<cute::is_same_v<EpilogueTileType, EpilogueTileAuto>,
|
||||
cute::Shape<_128, _64>,
|
||||
EpilogueTileType
|
||||
>;
|
||||
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_1>{}));
|
||||
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
|
||||
|
||||
using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized;
|
||||
|
||||
using AlignmentCType = Int<AlignmentC>;
|
||||
using AlignmentDType = Int<AlignmentD>;
|
||||
|
||||
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
|
||||
|
||||
static_assert(is_base_of_v<fusion::FusionOperation, FusionOp>, "only support EVT fusions");
|
||||
using FusionCallbacks =
|
||||
typename detail::CallbacksBuilder<
|
||||
DispatchPolicy,
|
||||
FusionOp,
|
||||
CtaTileShape_MNK,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
AccLoadOp
|
||||
>::Callbacks;
|
||||
|
||||
using CollectiveOp = cute::conditional_t<
|
||||
cute::is_same_v<EpilogueScheduleType, NoSmemWarpSpecialized>,
|
||||
cutlass::epilogue::collective::CollectiveEpilogue<
|
||||
cutlass::epilogue::Sm100NoSmemWarpSpecialized,
|
||||
EpilogueTile,
|
||||
ElementC_,
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
AccLoadOp,
|
||||
AlignmentCType,
|
||||
AlignmentDType
|
||||
>,
|
||||
cutlass::epilogue::collective::CollectiveEpilogue<
|
||||
cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized,
|
||||
EpilogueTile,
|
||||
ElementC_,
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
AccLoadOp
|
||||
>
|
||||
>;
|
||||
Int<AlignmentC>,
|
||||
Int<AlignmentD>
|
||||
>;
|
||||
};
|
||||
|
||||
// TMA epilogue builder
|
||||
template <
|
||||
class OpClass,
|
||||
class CtaTileShape_MNK, // Static CTA tile shape
|
||||
class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1)
|
||||
class MmaTileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
@@ -926,7 +868,7 @@ template <
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
OpClass,
|
||||
CtaTileShape_MNK,
|
||||
MmaTileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
@@ -940,30 +882,20 @@ struct CollectiveBuilder<
|
||||
EpilogueScheduleType,
|
||||
FusionOp,
|
||||
cute::enable_if_t<
|
||||
// OpClass
|
||||
( cute::is_same_v<OpClass, arch::OpClassTensorOp>
|
||||
|| cute::is_same_v<OpClass, arch::OpClassBlockScaledTensorOp>
|
||||
) &&
|
||||
// Epilogue Schedule Type
|
||||
( cute::is_base_of_v<TmaWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||
cute::is_base_of_v<TmaWarpSpecialized2Sm, EpilogueScheduleType>
|
||||
|| detail::IsPtrArrayDispatchPolicy<EpilogueScheduleType>
|
||||
)>>
|
||||
// Only support TensorOp kernels
|
||||
not cute::is_same_v<OpClass, arch::OpClassSimt> &&
|
||||
(cute::is_base_of_v<TmaWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||
cute::is_base_of_v<TmaWarpSpecialized2Sm, EpilogueScheduleType>)
|
||||
>
|
||||
>
|
||||
{
|
||||
private:
|
||||
using TmemWarpShape_MN = cute::conditional_t<size<0>(CtaTileShape_MNK{}) == 64 &&
|
||||
(cute::is_base_of_v<TmaWarpSpecialized2Sm, EpilogueScheduleType>
|
||||
|| cute::is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized2Sm>
|
||||
),
|
||||
Shape<_2,_2>, Shape<_4,_1>>;
|
||||
|
||||
public:
|
||||
using CollectiveOp =
|
||||
typename detail::Sm100TmaBuilderImpl<
|
||||
OpClass,
|
||||
CtaTileShape_MNK,
|
||||
MmaTileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
TmemWarpShape_MN,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
@@ -977,11 +909,11 @@ public:
|
||||
>::CollectiveOp;
|
||||
};
|
||||
|
||||
// Auto builder
|
||||
// Auto epilogue builder for TensorOp kernels
|
||||
template <
|
||||
class OpClass,
|
||||
class CtaTileShape_MNK, // Static CTA tile shape
|
||||
class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1)
|
||||
class MmaTileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
@@ -991,13 +923,12 @@ template <
|
||||
class ElementD,
|
||||
class GmemLayoutTagD,
|
||||
int AlignmentD,
|
||||
class EpilogueScheduleType,
|
||||
class FusionOp
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
OpClass,
|
||||
CtaTileShape_MNK,
|
||||
MmaTileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
@@ -1008,30 +939,41 @@ struct CollectiveBuilder<
|
||||
ElementD,
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
EpilogueScheduleAuto,
|
||||
FusionOp,
|
||||
cute::enable_if_t<
|
||||
// OpClass
|
||||
( cute::is_same_v<OpClass, arch::OpClassTensorOp>
|
||||
|| cute::is_same_v<OpClass, arch::OpClassBlockScaledTensorOp>
|
||||
)
|
||||
// Epilogue Schedule Type
|
||||
&& cute::is_same_v<EpilogueScheduleType, EpilogueScheduleAuto>>
|
||||
// only for TensorOp kernels
|
||||
cute::enable_if_t<not cute::is_same_v<OpClass, arch::OpClassSimt>>
|
||||
>
|
||||
{
|
||||
private:
|
||||
static_assert(cute::is_same_v<EpilogueTileType, EpilogueTileAuto>, "Don't specify epilogue tile with auto schedule");
|
||||
using TmemWarpShape_MN = cute::conditional_t<size<0>(CtaTileShape_MNK{}) == 64 &&
|
||||
size<0>(ClusterShape_MNK{}) % 2 == 0
|
||||
,
|
||||
Shape<_2,_2>, Shape<_4,_1>>;
|
||||
static constexpr bool
|
||||
is_2sm() {
|
||||
using namespace cute;
|
||||
constexpr int MmaTileM = size<0>(MmaTileShape_MNK{});
|
||||
constexpr int ClusterM = size<0>(ClusterShape_MNK{});
|
||||
constexpr bool StaticClusterM = is_static_v<decltype(get<0>(ClusterShape_MNK{}))>;
|
||||
constexpr bool EvenClusterM = StaticClusterM && ClusterM % 2 == 0;
|
||||
if constexpr (not EvenClusterM) {
|
||||
return false;
|
||||
}
|
||||
else if constexpr (is_same_v<OpClass,arch::OpClassBlockScaledTensorOp>) {
|
||||
return MmaTileM == 256;
|
||||
}
|
||||
else {
|
||||
return MmaTileM == 256 || MmaTileM == 128;
|
||||
}
|
||||
}
|
||||
using EpilogueSchedule = cute::conditional_t<is_2sm(), TmaWarpSpecialized2Sm, TmaWarpSpecialized1Sm>;
|
||||
|
||||
public:
|
||||
static_assert(cute::is_same_v<EpilogueTileType, EpilogueTileAuto>, "Don't specify epilogue tile with auto schedule");
|
||||
using CollectiveOp =
|
||||
typename detail::Sm100TmaBuilderImpl<
|
||||
typename CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
OpClass,
|
||||
CtaTileShape_MNK,
|
||||
MmaTileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
TmemWarpShape_MN,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
@@ -1040,7 +982,7 @@ public:
|
||||
ElementD,
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
EpilogueSchedule,
|
||||
FusionOp
|
||||
>::CollectiveOp;
|
||||
};
|
||||
|
||||
@@ -356,24 +356,21 @@ public:
|
||||
}
|
||||
|
||||
// Represent the full output tensor, slice to get the tile this CTA is responsible for
|
||||
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
|
||||
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
|
||||
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
|
||||
|
||||
// Partition source and destination tiles according to tmem copy T2R partitioning (tTR_)
|
||||
auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r));
|
||||
Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N)
|
||||
Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N)
|
||||
Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N)
|
||||
Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N)
|
||||
|
||||
|
||||
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
|
||||
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
|
||||
Tensor cD_epi = flat_divide( cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
|
||||
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
|
||||
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
|
||||
Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
|
||||
|
||||
// 2. Apply element-wise operation and store to gmem
|
||||
// source is needed
|
||||
@@ -410,7 +407,9 @@ template <
|
||||
class ElementD_,
|
||||
class StrideD_,
|
||||
class ThreadEpilogueOp_,
|
||||
class CopyOpT2R_
|
||||
class CopyOpT2R_,
|
||||
class AlignmentC,
|
||||
class AlignmentD
|
||||
>
|
||||
class CollectiveEpilogue<
|
||||
Sm100PtrArrayNoSmemWarpSpecialized,
|
||||
@@ -420,7 +419,9 @@ class CollectiveEpilogue<
|
||||
ElementD_,
|
||||
StrideD_,
|
||||
ThreadEpilogueOp_,
|
||||
CopyOpT2R_
|
||||
CopyOpT2R_,
|
||||
AlignmentC,
|
||||
AlignmentD
|
||||
> : public detail::Sm100TmaWarpSpecializedAdapter<CollectiveEpilogue<
|
||||
Sm100PtrArrayNoSmem,
|
||||
EpilogueTile_,
|
||||
|
||||
@@ -372,24 +372,21 @@ public:
|
||||
auto cta_tiler = take<0,2>(cta_tile_shape_mnk);
|
||||
|
||||
// Represent the full output tensor, slice to get the tile this CTA is responsible for
|
||||
Tensor mC = make_tensor(make_gmem_ptr<GmemElementC>(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
|
||||
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor mC = make_tensor(make_gmem_ptr<GmemElementC>(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
|
||||
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
|
||||
|
||||
// Partition source and destination tiles according to tmem copy T2R partitioning (tTR_)
|
||||
auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r));
|
||||
Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N)
|
||||
Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N)
|
||||
Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N)
|
||||
Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N)
|
||||
|
||||
|
||||
Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
|
||||
Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
|
||||
Tensor cD_epi = flat_divide( cCD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
|
||||
Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
|
||||
Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
|
||||
Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
|
||||
|
||||
// 2. Apply element-wise operation and store to gmem
|
||||
ThreadEpilogueOp epilogue_op{params.thread};
|
||||
@@ -587,18 +584,18 @@ public:
|
||||
|
||||
int thread_idx = threadIdx.x % ThreadCount;
|
||||
|
||||
Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N)
|
||||
Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N)
|
||||
Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{}));
|
||||
ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx);
|
||||
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N)
|
||||
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N)
|
||||
|
||||
constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount;
|
||||
|
||||
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
|
||||
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
|
||||
Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
|
||||
Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
|
||||
Tensor cD_epi = flat_divide(cD, EpilogueTile{});
|
||||
Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l)
|
||||
Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l)
|
||||
|
||||
Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tTR_cD(_,_,_,_0{},_0{})));
|
||||
|
||||
@@ -689,19 +686,22 @@ public:
|
||||
do_acc_release = iter_m == size<3>(tTR_tAcc)-1 && iter_n == 0;
|
||||
}
|
||||
|
||||
Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n);
|
||||
Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n);
|
||||
cst_callbacks.begin_loop(epi_m, epi_n);
|
||||
|
||||
if (is_C_load_needed) {
|
||||
Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int<VC>{})));
|
||||
Tensor tTR_gC_frg = recast<Array<GmemElementC, VC>>(coalesce(tTR_gC(_,_,_,epi_m,epi_n)));
|
||||
Tensor tTR_rC_frg = recast<Array<GmemElementC, VC>>(coalesce(tCrC));
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
if (is_C_load_needed) {
|
||||
using CVecType = uint_bit_t<VC * sizeof_bits_v<ElementC>>;
|
||||
Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int<VC>{})));
|
||||
|
||||
auto pred_fn_C = [&] (auto const&... coords) {
|
||||
return elem_less(tTR_cC_frag(coords...), problem_shape_mnl);
|
||||
};
|
||||
auto pred_fn_C = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
|
||||
return elem_less(tTR_cC_frag(coords...), problem_shape_mnl);
|
||||
};
|
||||
|
||||
copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg);
|
||||
Tensor tTR_gC_frg = recast<CVecType>(coalesce(tTR_gC(_,_,_,epi_m,epi_n)));
|
||||
Tensor tTR_rC_frg = recast<CVecType>(coalesce(tCrC));
|
||||
copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy accumulator tile from tmem to register
|
||||
@@ -733,17 +733,15 @@ public:
|
||||
|
||||
|
||||
Tensor tTR_cD_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclD.compose(Int<VD>{})));
|
||||
|
||||
using VecType = uint_bit_t<VD * sizeof_bits_v<ElementD>>;
|
||||
Tensor tTR_gD_frg = recast<VecType>(coalesce(tTR_gD(_,_,_,epi_m,epi_n)));
|
||||
Tensor tTR_rD_frg = recast<VecType>(coalesce(tTR_rD));
|
||||
|
||||
auto pred_fn_D = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
|
||||
return elem_less(tTR_cD_frag(coords...), problem_shape_mnl);
|
||||
};
|
||||
|
||||
copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg);
|
||||
using VecType = uint_bit_t<VD * sizeof_bits_v<ElementD>>;
|
||||
Tensor tTR_gD_frg = recast<VecType>(coalesce(tTR_gD(_,_,_,epi_m,epi_n)));
|
||||
Tensor tTR_rD_frg = recast<VecType>(coalesce(tTR_rD));
|
||||
|
||||
copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg);
|
||||
} // for epi_m
|
||||
} // for epi_n
|
||||
|
||||
|
||||
@@ -340,7 +340,7 @@ public:
|
||||
_1{});
|
||||
}
|
||||
|
||||
typename Params::TMA_D tma_store_d;
|
||||
typename Params::TMA_D tma_store_d{};
|
||||
if constexpr (is_destination_supported) {
|
||||
ElementD const* ptr_D_first_batch = reinterpret_cast<ElementD const*>(args.ptr_D);
|
||||
Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{})));
|
||||
|
||||
@@ -287,7 +287,7 @@ public:
|
||||
EpilogueTile{});
|
||||
}
|
||||
|
||||
typename Params::TMA_D tma_store_d;
|
||||
typename Params::TMA_D tma_store_d{};
|
||||
if constexpr (is_destination_supported) {
|
||||
Tensor tensor_d = make_tensor(make_gmem_ptr<TmaElementD>(args.ptr_D), make_layout(make_shape(M,N,L), args.dD));
|
||||
tma_store_d = make_tma_copy_C_sm90(
|
||||
|
||||
@@ -44,35 +44,30 @@ namespace cutlass::epilogue {
|
||||
// Builder Epilogue Schedules
|
||||
//
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Pre-Hopper schedules
|
||||
struct PtrArrayDefault {};
|
||||
struct EpilogueSimtVectorized {};
|
||||
struct EpiloguePtrArraySimtVectorized {};
|
||||
// Hopper direct store schedules
|
||||
struct NoSmemWarpSpecialized {};
|
||||
struct PtrArrayNoSmemWarpSpecialized {};
|
||||
struct PtrArrayNoSmemWarpSpecializedTransposed {};
|
||||
// Hopper TMA schedules
|
||||
struct TmaWarpSpecialized {};
|
||||
struct TmaWarpSpecializedCooperative {};
|
||||
|
||||
struct PtrArrayTmaWarpSpecialized { static constexpr int NumEpilogueWarpGroups = 1; };
|
||||
struct PtrArrayTmaWarpSpecializedPingpong { static constexpr int NumEpilogueWarpGroups = 2; };
|
||||
struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; };
|
||||
// Blackwell direct store schedules
|
||||
struct NoSmemWarpSpecialized1Sm {};
|
||||
struct NoSmemWarpSpecialized2Sm {};
|
||||
struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {};
|
||||
struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {};
|
||||
// Blackwell TMA schedules
|
||||
struct TmaWarpSpecialized1Sm {};
|
||||
struct TmaWarpSpecialized2Sm {};
|
||||
struct PtrArrayTmaWarpSpecialized1Sm {};
|
||||
struct PtrArrayTmaWarpSpecialized2Sm {};
|
||||
|
||||
struct PtrArrayTmaWarpSpecializedCooperative {
|
||||
static constexpr int NumEpilogueWarpGroups = 2;
|
||||
};
|
||||
|
||||
// Standard warp specialized epilogue
|
||||
struct PtrArrayTmaWarpSpecialized {
|
||||
static constexpr int NumEpilogueWarpGroups = 1;
|
||||
};
|
||||
|
||||
// Pingpong kernel epilogue
|
||||
struct PtrArrayTmaWarpSpecializedPingpong {
|
||||
static constexpr int NumEpilogueWarpGroups = 2;
|
||||
};
|
||||
|
||||
struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {};
|
||||
struct PtrArrayTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {};
|
||||
// DEPRECATED schedules, will be removed in next release
|
||||
struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {};
|
||||
struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {};
|
||||
|
||||
@@ -53,6 +53,7 @@ struct FusionOperation {
|
||||
// metadata types/queries that can be overrided
|
||||
using ElementOutput = void;
|
||||
using ElementCompute = void;
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate;
|
||||
|
||||
using ElementSource = void;
|
||||
static constexpr bool IsSourceSupported = false;
|
||||
|
||||
@@ -482,7 +482,6 @@ public:
|
||||
/// Note: The below method only when problem_size_K <= 256 for signed int8 gemm
|
||||
/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is
|
||||
/// above.
|
||||
/// TODO: Add logic to fallback to the default approach
|
||||
template <
|
||||
/// Data type used to load and store< tensors
|
||||
typename ElementOutput_,
|
||||
|
||||
@@ -39,13 +39,8 @@
|
||||
#include <type_traits>
|
||||
#endif
|
||||
#if !defined(__QNX__)
|
||||
#include <cuda/std/version>
|
||||
#if defined(_MSC_VER) && defined(CCCL_VERSION) && CCCL_VERSION >= 2008000
|
||||
#include <cuda/std/__utility/swap.h>
|
||||
#else
|
||||
#include <cuda/std/utility>
|
||||
#endif
|
||||
#endif
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/uint128.h"
|
||||
|
||||
@@ -51,18 +51,57 @@
|
||||
#ifdef _MSC_VER
|
||||
// Provides support for alternate operators such as 'and', 'or', ...
|
||||
#include <ciso646>
|
||||
#include <intrin.h>
|
||||
#endif // _MSC_VER
|
||||
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
# define CUTLASS_ARCH_CREDUX_ENABLED
|
||||
#endif
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
CUTLASS_HOST_DEVICE int32_t popcount(int32_t x) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return __popc(x);
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
return __builtin_popcount(x);
|
||||
#elif defined(_MSC_VER)
|
||||
return __popcnt(x);
|
||||
#else
|
||||
int32_t count = 0;
|
||||
while (x) {
|
||||
count += x & 1;
|
||||
x >>= 1;
|
||||
}
|
||||
return count;
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE int64_t popcount(int64_t x) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return __popcll(x);
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
return __builtin_popcountll(x);
|
||||
#elif defined(_MSC_VER)
|
||||
return __popcnt64(x);
|
||||
#else
|
||||
int64_t count = 0;
|
||||
while (x) {
|
||||
count += x & 1;
|
||||
x >>= 1;
|
||||
}
|
||||
return count;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
struct absolute_value_op {
|
||||
CUTLASS_HOST_DEVICE
|
||||
@@ -609,22 +648,7 @@ struct and_popc_add {
|
||||
CUTLASS_HOST_DEVICE
|
||||
C operator()(A const &a, B const &b, C const &c) const {
|
||||
A and_result = a & b;
|
||||
|
||||
#if defined(__CUDA__ARCH__)
|
||||
int popc_result = __popc(and_result);
|
||||
|
||||
if constexpr (sizeof(A) == sizeof(uint64_t)) {
|
||||
popc_result += __popc(static_cast<uint32_t>(and_result >> 32));
|
||||
}
|
||||
|
||||
#else
|
||||
int popc_result = __builtin_popcount(and_result);
|
||||
if constexpr (sizeof(A) == sizeof(uint64_t)) {
|
||||
popc_result += __builtin_popcount(static_cast<uint32_t>(and_result >> 32));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
int32_t popc_result = detail::popcount(and_result);
|
||||
return C(popc_result) + c;
|
||||
}
|
||||
};
|
||||
@@ -646,22 +670,7 @@ struct xor_popc_add {
|
||||
CUTLASS_HOST_DEVICE
|
||||
C operator()(A const &a, B const &b, C const &c) const {
|
||||
A xor_result = a ^ b;
|
||||
|
||||
#if defined(__CUDA__ARCH__)
|
||||
int popc_result = __popc(xor_result);
|
||||
|
||||
if constexpr (sizeof(A) == sizeof(uint64_t)) {
|
||||
popc_result += __popc(static_cast<uint32_t>(xor_result >> 32));
|
||||
}
|
||||
|
||||
#else
|
||||
int popc_result = __builtin_popcount(xor_result);
|
||||
if constexpr (sizeof(A) == sizeof(uint64_t)) {
|
||||
popc_result += __builtin_popcount(static_cast<uint32_t>(xor_result >> 32));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
int32_t popc_result = detail::popcount(xor_result);
|
||||
return C(popc_result) + c;
|
||||
}
|
||||
};
|
||||
@@ -682,22 +691,7 @@ struct or_popc_add {
|
||||
CUTLASS_HOST_DEVICE
|
||||
C operator()(A const &a, B const &b, C const &c) const {
|
||||
A or_result = a | b;
|
||||
|
||||
#if defined(__CUDA__ARCH__)
|
||||
int popc_result = __popc(or_result);
|
||||
|
||||
if constexpr (sizeof(A) == sizeof(uint64_t)) {
|
||||
popc_result += __popc(static_cast<uint32_t>(or_result >> 32));
|
||||
}
|
||||
|
||||
#else
|
||||
int popc_result = __builtin_popcount(or_result);
|
||||
if constexpr (sizeof(A) == sizeof(uint64_t)) {
|
||||
popc_result += __builtin_popcount(static_cast<uint32_t>(or_result >> 32));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
int32_t popc_result = detail::popcount(or_result);
|
||||
return C(popc_result) + c;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -567,7 +567,7 @@ sm100_make_trivial_fastFP32_tiled_mma() {
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check for U4_UNPACK_U8, U6_UNPACK_U8 alignment requirement
|
||||
* @brief Check for F8F6F4 alignment requirement
|
||||
*
|
||||
* @tparam TileShape_MNK (MmaAtomShape_M, MmaAtomShape_N, TileShape_K)
|
||||
* @tparam ClusterShape_MNK (cluster_M, cluster_N, cluster_K)
|
||||
|
||||
@@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale.
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int carveout_bytes_, int alignment = 128>
|
||||
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int carveout_bytes_, int alignment = 128>
|
||||
constexpr int
|
||||
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
|
||||
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
|
||||
@@ -96,7 +96,7 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_>
|
||||
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
|
||||
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
|
||||
cutlass::bits_to_bytes(scale_bits * ScaleMsPerTile) + // scale of tensor A
|
||||
cutlass::bits_to_bytes(scale_bits * 1); // scale of tensor B
|
||||
cutlass::bits_to_bytes(scale_bits * ScaleNsPerTile); // scale of tensor B
|
||||
|
||||
constexpr int stage_bytes = cutlass::round_up(stage_bytes_, alignment) +
|
||||
static_cast<int>(mainloop_pipeline_bytes);
|
||||
@@ -1043,7 +1043,8 @@ template <
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
int ScaleGranularityM_
|
||||
int ScaleGranularityM_,
|
||||
int ScaleGranularityN_
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
@@ -1058,11 +1059,11 @@ struct CollectiveBuilder<
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>,
|
||||
cute::enable_if_t<
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
||||
> {
|
||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>;
|
||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
@@ -1090,7 +1091,7 @@ struct CollectiveBuilder<
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>>;
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>>;
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
@@ -1109,12 +1110,15 @@ struct CollectiveBuilder<
|
||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape_MNK{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape_MNK{}) : ScaleGranularityN_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape_MNK{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape_MNK{}) / ScaleGranularityN;
|
||||
static_assert((size<0>(TileShape_MNK{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
static_assert((size<1>(TileShape_MNK{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N.");
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
||||
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_>;
|
||||
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile, ScaleNsPerTile>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
@@ -75,6 +75,15 @@ private:
|
||||
}
|
||||
|
||||
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
|
||||
CUTLASS_DEVICE
|
||||
void scale_core(ElementAccumulator const &scale) {
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i) * scale;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
@@ -94,6 +103,31 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScaleA,
|
||||
class LayoutScaleA,
|
||||
class EngineScaleB,
|
||||
class LayoutScaleB>
|
||||
CUTLASS_DEVICE
|
||||
void scale_core(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
|
||||
using TensorScaleA = cute::Tensor<EngineScaleA, LayoutScaleA>;
|
||||
using TensorScaleB = cute::Tensor<EngineScaleB, LayoutScaleB>;
|
||||
|
||||
static_assert(is_static<LayoutScaleA>::value, "ScaleA Layout should be static");
|
||||
static_assert(is_static<LayoutScaleB>::value, "ScaleB Layout should be static");
|
||||
static_assert(is_rmem<TensorScaleA>::value, "ScaleA tensor must be rmem resident.");
|
||||
static_assert(is_rmem<TensorScaleB>::value, "ScaleB tensor must be rmem resident.");
|
||||
|
||||
static_assert(LayoutAccum{}.shape() == LayoutScaleA{}.shape(), "Accumulator and scaleA must have same shape.");
|
||||
static_assert(LayoutAccum{}.shape() == LayoutScaleB{}.shape(), "Accumulator and scaleB must have same shape.");
|
||||
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i) * scaleA(i) * scaleB(i);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
GmmaFP8Accumulation(
|
||||
@@ -152,6 +186,16 @@ public:
|
||||
//
|
||||
|
||||
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void scale_if_needed(ElementAccumulator const &scale) {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
scale_core(scale);
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
@@ -165,7 +209,29 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScaleA,
|
||||
class LayoutScaleA,
|
||||
class EngineScaleB,
|
||||
class LayoutScaleB>
|
||||
CUTLASS_DEVICE
|
||||
void scale_if_needed(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
scale_core(scaleA, scaleB);
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void scale_residue_if_needed(ElementAccumulator const &scale) {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
scale_core(scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
@@ -175,6 +241,18 @@ public:
|
||||
scale_core(scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScaleA,
|
||||
class LayoutScaleA,
|
||||
class EngineScaleB,
|
||||
class LayoutScaleB>
|
||||
CUTLASS_DEVICE
|
||||
void scale_residue_if_needed(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
scale_core(scaleA, scaleB);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
@@ -30,8 +30,6 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@@ -288,23 +286,23 @@ struct CollectiveMma<
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
|
||||
static constexpr uint32_t SFTransactionBytes =
|
||||
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v<ElementSF>) +
|
||||
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v<ElementSF>);
|
||||
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
|
||||
static constexpr uint32_t ABTmaTransactionBytes =
|
||||
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
|
||||
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
|
||||
static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes;
|
||||
|
||||
template<class AccTensor, class SfaTensor, class SfbTensor>
|
||||
template <class AccTensor, class SfaTensor, class SfbTensor>
|
||||
struct TmemStorage {
|
||||
AccTensor accumulators;
|
||||
SfaTensor tCtSFA;
|
||||
SfbTensor tCtSFB;
|
||||
};
|
||||
|
||||
template<
|
||||
template <
|
||||
class KTileCount,
|
||||
class GTensorPartitionedA, class GTensorPartitionedB,
|
||||
class STensorA, class STensorB,
|
||||
@@ -348,7 +346,8 @@ struct CollectiveMma<
|
||||
, mcast_mask_sfa(mcast_mask_sfa_), mcast_mask_sfb(mcast_mask_sfb_) {}
|
||||
};
|
||||
|
||||
template<
|
||||
template <
|
||||
class TiledMma,
|
||||
class FragmentA, class FragmentB,
|
||||
class FragmentSFA, class FragmentSFB,
|
||||
class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA,
|
||||
@@ -496,6 +495,7 @@ struct CollectiveMma<
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
||||
auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape);
|
||||
|
||||
// Cluster layout for TMA construction
|
||||
auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{}));
|
||||
auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback);
|
||||
@@ -505,7 +505,7 @@ struct CollectiveMma<
|
||||
|
||||
// Cluster layout for TMA construction of SFB
|
||||
auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{}));
|
||||
auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{}));
|
||||
auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{}));
|
||||
|
||||
typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100<TmaInternalElementA>(
|
||||
GmemTiledCopyA{},
|
||||
@@ -649,7 +649,7 @@ struct CollectiveMma<
|
||||
return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage));
|
||||
}
|
||||
|
||||
template<class EpilogueTile, bool IsOverlappingAccum = false>
|
||||
template <class EpilogueTile, bool IsOverlappingAccum = false>
|
||||
CUTLASS_DEVICE static
|
||||
auto
|
||||
init_tmem_tensors(EpilogueTile epi_tile) {
|
||||
@@ -660,7 +660,7 @@ struct CollectiveMma<
|
||||
tiled_mma, acc_shape, EpilogueTile{});
|
||||
Tensor tCtSFA = make_tensor<typename TiledMma::FrgTypeSFA>(shape(SmemLayoutAtomSFA{}));
|
||||
Tensor tCtSFB = make_tensor<typename TiledMma::FrgTypeSFB>(shape(SmemLayoutAtomSFB{}));
|
||||
|
||||
|
||||
TmemStorage<decltype(accumulators), decltype(tCtSFA), decltype(tCtSFB)> tmem_storage;
|
||||
tmem_storage.accumulators = accumulators;
|
||||
tmem_storage.tCtSFA = tCtSFA;
|
||||
@@ -669,10 +669,10 @@ struct CollectiveMma<
|
||||
return tmem_storage;
|
||||
}
|
||||
|
||||
template<class AccTensor, class SfaTensor, class SfbTensor>
|
||||
template <class TmemStorage>
|
||||
CUTLASS_DEVICE static
|
||||
void
|
||||
set_tmem_offsets(TmemStorage<AccTensor, SfaTensor, SfbTensor>& tmem_storage, uint32_t tmem_base_addr) {
|
||||
set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
|
||||
tmem_storage.accumulators.data() = tmem_base_addr;
|
||||
tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators);
|
||||
tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA);
|
||||
@@ -751,7 +751,6 @@ struct CollectiveMma<
|
||||
Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});
|
||||
|
||||
// Define the CTA-in-cluster Layout and Coord
|
||||
|
||||
Layout cta_layout_mnk = make_layout(cluster_shape_);
|
||||
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
|
||||
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_);
|
||||
@@ -785,13 +784,11 @@ struct CollectiveMma<
|
||||
uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
|
||||
uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk);
|
||||
|
||||
LoadParams load_params {
|
||||
return LoadParams{
|
||||
size<3>(gA_mkl), // for scheduler
|
||||
tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values
|
||||
tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values
|
||||
mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb // multicast masks
|
||||
};
|
||||
return load_params;
|
||||
mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb}; // multicast masks
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for mma compute.
|
||||
@@ -802,8 +799,8 @@ struct CollectiveMma<
|
||||
TensorStorage& shared_tensors) const {
|
||||
|
||||
// Allocate "fragments/descriptors" for A and B matrices
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors" for A and B matrices
|
||||
Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
@@ -854,17 +851,12 @@ struct CollectiveMma<
|
||||
tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111;
|
||||
tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111;
|
||||
}
|
||||
MmaParams<
|
||||
decltype(tCrA), decltype(tCrB), decltype(tCtSFA), decltype(tCtSFB),
|
||||
decltype(tiled_copy_s2t_SFA), decltype(thr_tCsSFA_compact_s2t), decltype(thr_tCtSFA_compact_s2t),
|
||||
decltype(tiled_copy_s2t_SFB), decltype(thr_tCsSFB_compact_s2t), decltype(thr_tCtSFB_compact_s2t)
|
||||
> mma_params {
|
||||
|
||||
return MmaParams{
|
||||
tiled_mma,
|
||||
tCrA, tCrB, tCtSFA, tCtSFB,
|
||||
tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t,
|
||||
tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t
|
||||
};
|
||||
return mma_params;
|
||||
tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t};
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
@@ -983,52 +975,12 @@ struct CollectiveMma<
|
||||
|
||||
uint32_t skip_wait = k_tile_count <= 0;
|
||||
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
|
||||
bool is_first_iter = true;
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
if (k_tile_count > 0) { // first iteraion
|
||||
// WAIT on mainloop_pipe_consumer_state until its data are available
|
||||
// (phase bit flips from mainloop_pipe_consumer_state.phase() value)
|
||||
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
|
||||
|
||||
// Compute on k_tile
|
||||
int read_stage = mainloop_pipe_consumer_state.index();
|
||||
// Save current mainlop pipeline read state
|
||||
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
|
||||
|
||||
// Advance mainloop_pipe
|
||||
++mainloop_pipe_consumer_state;
|
||||
--k_tile_count;
|
||||
skip_wait = k_tile_count <= 0;
|
||||
// Peek at next iteration
|
||||
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t);
|
||||
copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
|
||||
}
|
||||
|
||||
if constexpr (IsOverlappingAccum) {
|
||||
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
|
||||
}
|
||||
|
||||
// Unroll the K mode manually so we can set scale C to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M) x (V,N) => (V,M,N)
|
||||
cute::gemm(tiled_mma.with(tiled_mma.accumulate_,
|
||||
tCtSFA(_,_,k_block),
|
||||
tCtSFB_mma(_,_,k_block)),
|
||||
tCrA(_,_,k_block,read_stage),
|
||||
tCrB(_,_,k_block,read_stage),
|
||||
accumulators);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0) {
|
||||
// WAIT on mainloop_pipe_consumer_state until its data are available
|
||||
@@ -1052,6 +1004,13 @@ struct CollectiveMma<
|
||||
copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
|
||||
}
|
||||
|
||||
if constexpr (IsOverlappingAccum) {
|
||||
if (is_first_iter) {
|
||||
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
|
||||
is_first_iter = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Unroll the K mode manually so we can set scale C to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
@@ -1064,6 +1023,7 @@ struct CollectiveMma<
|
||||
accumulators);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@
|
||||
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@@ -239,12 +238,12 @@ struct CollectiveMma<
|
||||
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
|
||||
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
|
||||
|
||||
template<class AccTensor>
|
||||
template <class AccTensor>
|
||||
struct TmemStorage {
|
||||
AccTensor accumulators;
|
||||
};
|
||||
|
||||
template<
|
||||
template <
|
||||
class KTileCount,
|
||||
class GTensorPartitionedA, class GTensorPartitionedB,
|
||||
class STensorA, class STensorB
|
||||
@@ -273,7 +272,10 @@ struct CollectiveMma<
|
||||
, mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {}
|
||||
};
|
||||
|
||||
template<class FragmentA, class FragmentB>
|
||||
template <
|
||||
class TiledMma,
|
||||
class FragmentA, class FragmentB
|
||||
>
|
||||
struct MmaParams {
|
||||
TiledMma tiled_mma;
|
||||
FragmentA tCrA;
|
||||
@@ -336,7 +338,7 @@ struct CollectiveMma<
|
||||
, runtime_data_type_a_(params.runtime_data_type_a)
|
||||
, runtime_data_type_b_(params.runtime_data_type_b) {
|
||||
if constexpr (IsDynamicCluster) {
|
||||
const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
|
||||
const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
|
||||
cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
|
||||
observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a;
|
||||
observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b;
|
||||
@@ -461,7 +463,7 @@ struct CollectiveMma<
|
||||
return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage));
|
||||
}
|
||||
|
||||
template<class EpilogueTile, bool IsOverlappingAccum = false>
|
||||
template <class EpilogueTile, bool IsOverlappingAccum = false>
|
||||
CUTLASS_DEVICE static
|
||||
auto
|
||||
init_tmem_tensors(EpilogueTile epi_tile) {
|
||||
@@ -475,10 +477,10 @@ struct CollectiveMma<
|
||||
return tmem_storage;
|
||||
}
|
||||
|
||||
template<class AccTensor>
|
||||
template <class TmemStorage>
|
||||
CUTLASS_DEVICE static
|
||||
void
|
||||
set_tmem_offsets(TmemStorage<AccTensor>& tmem_storage, uint32_t tmem_base_addr) {
|
||||
set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
|
||||
tmem_storage.accumulators.data() = tmem_base_addr;
|
||||
}
|
||||
|
||||
@@ -535,21 +537,21 @@ struct CollectiveMma<
|
||||
// TMA Multicast Masks
|
||||
uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
|
||||
uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
|
||||
|
||||
LoadParams load_params {
|
||||
|
||||
return LoadParams{
|
||||
shape<3>(gA_mkl), // for scheduler
|
||||
tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values
|
||||
mcast_mask_a, mcast_mask_b // multicast masks
|
||||
};
|
||||
return load_params;
|
||||
mcast_mask_a, mcast_mask_b}; // multicast masks
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for mma compute.
|
||||
template <class AccTensor>
|
||||
template <class TmemStorage>
|
||||
CUTLASS_DEVICE auto
|
||||
mma_init(
|
||||
[[maybe_unused]] TmemStorage<AccTensor> tmem_tensors,
|
||||
TensorStorage& shared_tensors) const {
|
||||
[[maybe_unused]] TmemStorage tmem_storage,
|
||||
TensorStorage& shared_tensors) const {
|
||||
|
||||
// Allocate "fragments/descriptors" for A and B matrices
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
@@ -558,7 +560,7 @@ struct CollectiveMma<
|
||||
Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB));
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB)); // PIPE
|
||||
|
||||
TiledMma tiled_mma;
|
||||
|
||||
@@ -568,11 +570,10 @@ struct CollectiveMma<
|
||||
tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111;
|
||||
tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111;
|
||||
}
|
||||
MmaParams<decltype(tCrA), decltype(tCrB)> mma_params {
|
||||
|
||||
return MmaParams{
|
||||
tiled_mma,
|
||||
tCrA, tCrB
|
||||
};
|
||||
return mma_params;
|
||||
tCrA, tCrB};
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
@@ -657,6 +658,7 @@ struct CollectiveMma<
|
||||
) {
|
||||
static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
|
||||
static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
|
||||
|
||||
auto accumulators = get<0>(accumulators_pair);
|
||||
auto [tiled_mma, tCrA, tCrB] = mma_inputs;
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ template <
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
int ScaleGranularityM_,
|
||||
int ScaleGranularityN_,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
@@ -73,7 +74,7 @@ template <
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
@@ -92,7 +93,7 @@ struct CollectiveMma<
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
@@ -120,7 +121,9 @@ struct CollectiveMma<
|
||||
static constexpr int NumProducerThreadEvents = 2;
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
@@ -131,6 +134,7 @@ struct CollectiveMma<
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
@@ -144,12 +148,13 @@ struct CollectiveMma<
|
||||
|
||||
// Block scaling gmem-to-smem copy atom
|
||||
using BlockScaleCopyTypeA = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>;
|
||||
using BlockScaleCopyTypeB = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleNsPerTile, 16)>;
|
||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeA>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeB>, ElementBlockScale>;
|
||||
|
||||
// Block scaling smem layout
|
||||
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
|
||||
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
|
||||
using SmemLayoutScaleB = Layout<Shape<Int<ScaleNsPerTile>, Int<DispatchPolicy::Stages>>>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
@@ -168,7 +173,7 @@ struct CollectiveMma<
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // ScaleNsPerTile x k
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
@@ -322,17 +327,17 @@ struct CollectiveMma<
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
// Make the tiled views of scale tensors
|
||||
auto scaleA_shape = make_shape(get<2>(gA_mkl.shape()), Int<ScaleMsPerTile>{}, get<3>(gA_mkl.shape()), get<4>(gA_mkl.shape())); // (m,ScaleMsPerTile,k,l)
|
||||
auto scale_dA = make_stride(get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{}, Int<1>{}, Int<ScaleMsPerTile>{}, get<2>(gA_mkl.shape()) * get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{});
|
||||
auto scaleA_shape = make_shape(shape<2>(gA_mkl), Int<ScaleMsPerTile>{}, shape<3>(gA_mkl), shape<4>(gA_mkl)); // (m,ScaleMsPerTile,k,l)
|
||||
auto scaleB_shape = make_shape(shape<2>(gB_nkl), Int<ScaleNsPerTile>{}, shape<3>(gB_nkl), shape<4>(gB_nkl)); // (n,ScaleNsPerTile,k,l)
|
||||
auto scale_dA = compact_order(scaleA_shape, Step<_2,_0,_1,_3>{});
|
||||
auto scale_dB = compact_order(scaleB_shape, Step<_2,_0,_1,_3>{});
|
||||
auto scaleA_layout = make_layout(scaleA_shape, scale_dA);
|
||||
auto scaleB_shape = make_shape(get<2>(gB_nkl.shape()), get<3>(gB_nkl.shape()), get<4>(gB_nkl.shape())); // (n,k,l)
|
||||
auto scale_dB = make_stride(get<3>(gB_nkl.shape()), Int<1>{}, get<2>(gB_nkl.shape()) * get<3>(gB_nkl.shape()));
|
||||
auto scaleB_layout = make_layout(scaleB_shape, scale_dB);
|
||||
|
||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
||||
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
|
||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,ScaleNsPerTile,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
|
||||
}
|
||||
@@ -356,13 +361,13 @@ struct CollectiveMma<
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
||||
if (lane_predicate) {
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
@@ -388,10 +393,10 @@ struct CollectiveMma<
|
||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
||||
|
||||
Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1)
|
||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
|
||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,_,l_coord); // (1,ScaleNsPerTile,k,1)
|
||||
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleMsPerTile>>>{}); // (1,ScaleMsPerTile,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleNsPerTile>>>{}); // (1,ScaleNsPerTile,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
||||
|
||||
@@ -446,7 +451,7 @@ struct CollectiveMma<
|
||||
|
||||
// Copy scale tensors from global memory to shared memory
|
||||
copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
||||
copy(scale_copy_b, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage));
|
||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||
|
||||
++k_tile_iter;
|
||||
@@ -508,7 +513,11 @@ struct CollectiveMma<
|
||||
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
||||
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
||||
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()),
|
||||
Layout<
|
||||
Shape<cute::tuple_element_t<0, TileShape>, Shape<Int<ScaleGranularityN>, Int<ScaleNsPerTile>>, Int<DispatchPolicy::Stages>>,
|
||||
Stride<_0, Stride<_0, _1>, Int<ScaleNsPerTile>>
|
||||
>{}); // (m,(ScaleGranularityN,ScaleNsPerTile),k)
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
@@ -531,7 +540,8 @@ struct CollectiveMma<
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
||||
|
||||
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||
Tensor tCsScaleBViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
@@ -557,11 +567,8 @@ struct CollectiveMma<
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Per block scale values for operand A and B
|
||||
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
||||
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
||||
|
||||
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
||||
ElementBlockScale scale_b;
|
||||
Tensor tCrScaleAViewAsC = make_tensor_like<ElementBlockScale>(tCsScaleAViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N)
|
||||
Tensor tCrScaleBViewAsC = make_tensor_like<ElementBlockScale>(tCsScaleBViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
@@ -583,21 +590,26 @@ struct CollectiveMma<
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers.
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
// Load per block scale values from shared memory to registers
|
||||
copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC);
|
||||
copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC);
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
|
||||
tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0];
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
|
||||
ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
for (int i = 0; i < size(tCrScaleAViewAsC); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
|
||||
ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(tCrScaleBViewAsC); i++) {
|
||||
tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a;
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
@@ -609,8 +621,20 @@ struct CollectiveMma<
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
|
||||
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
|
||||
accumulation.scale_if_needed(scale_ab);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
|
||||
accumulation.scale_if_needed(tCrScaleBViewAsC);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
|
||||
}
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
@@ -632,21 +656,26 @@ struct CollectiveMma<
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
// Load per block scale values from shared memory to registers (at most twice per block along M and/or N)
|
||||
copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC);
|
||||
copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC);
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
|
||||
tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0];
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
|
||||
ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
for (int i = 0; i < size(tCrScaleAViewAsC); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
|
||||
ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(tCrScaleBViewAsC); i++) {
|
||||
tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a;
|
||||
}
|
||||
}
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
@@ -667,8 +696,20 @@ struct CollectiveMma<
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation());
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
|
||||
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
|
||||
accumulation.scale_if_needed(scale_ab);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
|
||||
accumulation.scale_if_needed(tCrScaleBViewAsC);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
|
||||
}
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
|
||||
@@ -677,7 +718,19 @@ struct CollectiveMma<
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
|
||||
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
|
||||
accumulation.scale_residue_if_needed(scale_ab);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
|
||||
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
|
||||
accumulation.scale_residue_if_needed(tCrScaleBViewAsC);
|
||||
}
|
||||
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
|
||||
accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
}
|
||||
|
||||
@@ -117,7 +117,11 @@ struct KernelPtrArrayTmaWarpSpecializedPingpong { };
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
template<
|
||||
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
|
||||
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
|
||||
int ScaleGranularityM = 0,
|
||||
int ScaleGranularityN = 0
|
||||
>
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
|
||||
|
||||
@@ -302,12 +306,16 @@ template<
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecialized,
|
||||
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
|
||||
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
|
||||
int ScaleGranularityM = 0,
|
||||
int ScaleGranularityN = 0
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
||||
static_assert(
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>>,
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
};
|
||||
|
||||
|
||||
@@ -397,8 +397,6 @@ public:
|
||||
// An example of an unneeded threadblock is one that is assigned to compute in the upper
|
||||
// portion of a Rank2K kernel filled with mode kLower.
|
||||
//
|
||||
// TODO: Consider pushing these checks into ProblemVisitor to avoid spuriously
|
||||
// returning from `next_tile()`.
|
||||
//
|
||||
|
||||
// Early exit if threadblock is out of range
|
||||
|
||||
@@ -1131,6 +1131,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
else {
|
||||
// Register reconfiguration
|
||||
arch::warpgroup_reg_dealloc<GenericRegisterRequirement>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -29,8 +29,6 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@@ -564,20 +562,21 @@ public:
|
||||
// Sync deallocation status between MMA warps of peer CTAs
|
||||
arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc;
|
||||
[[maybe_unused]] uint32_t dealloc_barrier_phase = 0;
|
||||
if constexpr(!IsOverlappingAccum) {
|
||||
if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) {
|
||||
tmem_deallocation_result_barrier.init(NumMMAThreads);
|
||||
if (WarpCategory::MMA == warp_category) {
|
||||
if constexpr(!IsOverlappingAccum) {
|
||||
if (has_mma_peer_cta && lane_predicate) {
|
||||
tmem_deallocation_result_barrier.init(NumMMAThreads);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (has_mma_peer_cta && lane_predicate) {
|
||||
tmem_deallocation_result_barrier.init(NumEpilogueThreads*2);
|
||||
}
|
||||
else if (lane_predicate) {
|
||||
tmem_deallocation_result_barrier.init(NumEpilogueThreads);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) {
|
||||
tmem_deallocation_result_barrier.init(NumEpilogueThreads*2);
|
||||
}
|
||||
else if (WarpCategory::MMA == warp_category && lane_predicate) {
|
||||
tmem_deallocation_result_barrier.init(NumEpilogueThreads);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes.
|
||||
arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle;
|
||||
@@ -699,7 +698,6 @@ public:
|
||||
epilogue_throttle_barrier.arrive();
|
||||
|
||||
if constexpr (IsSchedDynamicPersistent) {
|
||||
|
||||
// Whether a new CLC query must be performed.
|
||||
// See comment below where this variable is updated for a description of
|
||||
// why this variable is needed.
|
||||
@@ -738,7 +736,6 @@ public:
|
||||
work_tile_info = next_work_tile_info;
|
||||
} while (work_tile_info.is_valid());
|
||||
clc_pipeline.producer_tail(clc_pipe_producer_state);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -963,7 +960,6 @@ public:
|
||||
epi_load_pipe_consumer_state = load_state_next;
|
||||
epi_store_pipe_producer_state = store_state_next;
|
||||
accumulator_pipe_consumer_state = acc_state_next;
|
||||
|
||||
do_tail_store = true;
|
||||
}
|
||||
work_tile_info = next_work_tile_info;
|
||||
|
||||
@@ -1057,6 +1057,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
else {
|
||||
// Register reconfiguration
|
||||
arch::warpgroup_reg_dealloc<GenericRegisterRequirement>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -783,7 +783,6 @@ private:
|
||||
int L_idx, Split_idx;
|
||||
params_.sk_params_.divmod_splits_(L_idx, Split_idx, work_tile_info.L_idx);
|
||||
|
||||
// TODO: Modularize the SM90 scheduler to pull out and reuse this redundant code
|
||||
int additional_k_tiles = 0;
|
||||
int split_start_offset = params_.sk_params_.big_units_;
|
||||
|
||||
|
||||
@@ -455,8 +455,9 @@ public:
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
|
||||
|
||||
// Declare work_tile_info, then define it in each of warps that use it.
|
||||
typename TileScheduler::WorkTileInfo work_tile_info;
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
|
||||
@@ -474,6 +475,7 @@ public:
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
@@ -578,6 +580,7 @@ public:
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
|
||||
work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
@@ -265,7 +265,7 @@ struct PersistentTileSchedulerSm90Params {
|
||||
}
|
||||
// In case the maximum number of clusters that could co-exist on the target device is
|
||||
// already calculated using cudaOccupancyMaxActiveClusters
|
||||
else if (max_active_clusters != 0) {
|
||||
else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) {
|
||||
if (raster_order == RasterOrder::AlongN) {
|
||||
launch_grid.y = max_active_clusters * cluster_shape.n();
|
||||
}
|
||||
@@ -1204,6 +1204,7 @@ struct PersistentTileSchedulerSm90StreamKParams {
|
||||
KernelHardwareInfo new_hw_info;
|
||||
new_hw_info.device_id = hw_info.device_id;
|
||||
new_hw_info.sm_count = hw_info.sm_count;
|
||||
new_hw_info.max_active_clusters = hw_info.max_active_clusters;
|
||||
if (new_hw_info.sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
@@ -1787,7 +1788,7 @@ struct PersistentTileSchedulerSm90GroupParams {
|
||||
}
|
||||
// In case the maximum number of clusters that could co-exist on the target device is
|
||||
// already calculated using cudaOccupancyMaxActiveClusters
|
||||
else if (max_active_clusters != 0) {
|
||||
else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) {
|
||||
if (raster_order == RasterOrder::AlongN) {
|
||||
launch_grid.y = max_active_clusters * cluster_shape.n();
|
||||
}
|
||||
@@ -2499,6 +2500,7 @@ struct PersistentTileSchedulerSm100GroupParams {
|
||||
bool is_static_cluster_shape = false) {
|
||||
|
||||
int const sm_count = hw_info.sm_count;
|
||||
int const max_active_clusters = hw_info.max_active_clusters;
|
||||
|
||||
// Round up to nearest multiple of swizzle_size along each mode
|
||||
auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size);
|
||||
@@ -2542,6 +2544,18 @@ struct PersistentTileSchedulerSm100GroupParams {
|
||||
launch_grid.x = possibly_truncate(sm_count, problem_blocks_total);
|
||||
}
|
||||
}
|
||||
// In case the maximum number of clusters that could co-exist on the target device is
|
||||
// already calculated using cudaOccupancyMaxActiveClusters
|
||||
else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) {
|
||||
if (raster_order == RasterOrder::AlongN) {
|
||||
launch_grid.y = max_active_clusters * cluster_shape.n();
|
||||
}
|
||||
else {
|
||||
launch_grid.x = max_active_clusters * cluster_shape.m();
|
||||
}
|
||||
CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using cudaOccupancyMaxActiveClusters = "
|
||||
"(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n");
|
||||
}
|
||||
else {
|
||||
constexpr int max_sm_per_gpc = 20;
|
||||
int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count);
|
||||
|
||||
@@ -142,7 +142,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
// Shape of one individual LDS.128
|
||||
// TODO: 32 and 4 are hardcoded, 32-by-4 is logical shape
|
||||
using LdsShape = layout::PitchLinearShape<
|
||||
32,
|
||||
4
|
||||
@@ -458,7 +457,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
// Shape of one individual LDS
|
||||
// TODO: remove hardcoded 32 and 4
|
||||
using LdsShape = layout::PitchLinearShape<
|
||||
32,
|
||||
4
|
||||
|
||||
@@ -995,7 +995,6 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) {
|
||||
|
||||
// TODO: fix this if it becomes an issue during warp it reset
|
||||
add_tile_offset(tile_offset);
|
||||
|
||||
return *this;
|
||||
|
||||
@@ -41,7 +41,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
|
||||
@@ -82,7 +82,7 @@ struct get_unpacked_element_type {
|
||||
#include "cutlass/tfloat32.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/uint128.h"
|
||||
#include "cutlass/exmy_base.h"
|
||||
#include "cutlass/float_subbyte.h"
|
||||
#include "cutlass/exmy_base.h"
|
||||
#include "cutlass/float_subbyte.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
Reference in New Issue
Block a user