More updates for 3.1 (#958)

* Updates for 3.1

* Minor change

* doc link fix

* Minor updates
ANIKET SHIVAM
2023-05-24 07:17:16 -07:00
committed by GitHub
parent 13f413493a
commit f079619f5e
48 changed files with 1611 additions and 1858 deletions


@@ -3,7 +3,7 @@
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
* New CUTLASS Python interface that aims to provide an easy-to-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper.
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues.
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
@@ -12,6 +12,11 @@
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
* [StreamK GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
* [Batched B2B GEMM](examples/13_two_tensor_op_fusion) can now run multiple back-to-back GEMMs with the same problem size in parallel.
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
* [Permute + GEMM fusion](examples/39_gemm_permute) can now fuse Permute with the following GEMM. Previously, only fusing GEMM with Permute in the epilogue was supported.
* The GitHub branch is renamed from `master` to `main` in this release.
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)


@@ -46,7 +46,7 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im
CUTLASS 3.1 is an update to CUTLASS adding:
- New CUTLASS Python interface that aims to provide an easy-to-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper.
- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
- Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues.
- New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
- New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that improves performance on Hopper.
@@ -54,6 +54,12 @@ CUTLASS 3.1 is an update to CUTLASS adding:
- New Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
- Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
- Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
- [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
- [StreamK GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
- [Batched B2B GEMM](examples/13_two_tensor_op_fusion) can now run multiple back-to-back GEMMs with the same problem size in parallel.
- [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
- [Permute + GEMM fusion](examples/39_gemm_permute) can now fuse Permute with the following GEMM. Previously, only fusing GEMM with Permute in the epilogue was supported.
- *Announcement*:
- The GitHub branch is renamed from `master` to `main` in this release.
- A slight modification has been made to the ordering of arguments passed in to epilogues in 3.x kernels.


@@ -641,6 +641,11 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
// 2nd Gemm
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
@@ -871,7 +876,10 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -664,6 +664,11 @@ public:
}
// Insert fence and wait for all outstanding cp.async operations to commit.
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
/// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0;
@@ -855,7 +860,10 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -759,13 +759,10 @@ public:
accum1 = plus_accum(accum1, tmp_accum1);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};
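The hunks above apply the same fix in several kernels: after the mainloop (or before the Zfill special case was hoisted out), all outstanding `cp.async` copies must be committed and drained before shared memory is reused. This is not the actual CUTLASS code, but the `cutlass::arch` wrappers used here boil down to PTX along these lines (sm_80+ assumed; the function name is ours):

```cuda
// Illustrative sketch of the commit-and-drain pattern the hunks above insert.
__device__ void drain_cp_async_before_smem_reuse() {
  // Commit all previously issued cp.async instructions into one group
  // (what cutlass::arch::cp_async_fence() emits).
  asm volatile("cp.async.commit_group;\n" ::);
  // Wait until 0 groups remain in flight, i.e. every async copy has landed
  // (what cutlass::arch::cp_async_wait<0>() emits).
  asm volatile("cp.async.wait_group 0;\n" ::);
  // Make the freshly written shared memory visible to all threads in the CTA.
  __syncthreads();
}
```

Without this drain, a subsequent epilogue or second GEMM could read shared-memory tiles that an in-flight `cp.async` is still writing.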


@@ -461,11 +461,6 @@ Result run(std::string description, Options &options)
std::cout << " GFLOPs: " << result.gflops << std::endl;
}
// TODO: uncomment when results match
//if (!result.passed) {
// exit(-1);
//}
return result;
}


@@ -499,7 +499,7 @@ flatten(T const& t)
namespace detail {
// Shortcut around tuple_cat for common insert/remove/repeat cases
// Shortcut around cute::tuple_cat for common insert/remove/repeat cases
template <class T, class X, int... I, int... J, int... K>
CUTE_HOST_DEVICE constexpr
auto


@@ -623,7 +623,7 @@ partition_shape_C(TiledMMA<Args...> const& mma, Shape_MN const& shape_MN)
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutC_TV{});
auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK));
auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK));
return tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN));
return cute::tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN));
}
template <class... Args, class Shape_MN>
@@ -651,7 +651,7 @@ partition_shape_A(TiledMMA<Args...> const& mma, Shape_MK const& shape_MK)
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutA_TV{});
auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK));
auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK));
return tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK));
return cute::tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK));
}
template <class... Args, class Shape_NK>
@@ -666,7 +666,7 @@ partition_shape_B(TiledMMA<Args...> const& mma, Shape_NK const& shape_NK)
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutB_TV{});
auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK));
auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK));
return tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK));
return cute::tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK));
}
//


@@ -46,8 +46,14 @@ namespace cute
using dim3 = ::dim3;
// MSVC doesn't define its C++ version macro to match
// its C++ language version. This means that when
// building with MSVC, dim3 isn't constexpr-friendly.
template <size_t I>
CUTE_HOST_DEVICE constexpr
CUTE_HOST_DEVICE
#if ! defined(_MSC_VER)
constexpr
#endif
uint32_t& get(dim3& a)
{
static_assert(I < 3, "Index out of range");
@@ -63,7 +69,10 @@ uint32_t& get(dim3& a)
}
template <size_t I>
CUTE_HOST_DEVICE constexpr
CUTE_HOST_DEVICE
#if ! defined(_MSC_VER)
constexpr
#endif
uint32_t const& get(dim3 const& a)
{
static_assert(I < 3, "Index out of range");
@@ -79,7 +88,10 @@ uint32_t const& get(dim3 const& a)
}
template <size_t I>
CUTE_HOST_DEVICE constexpr
CUTE_HOST_DEVICE
#if ! defined(_MSC_VER)
constexpr
#endif
uint32_t&& get(dim3&& a)
{
static_assert(I < 3, "Index out of range");


@@ -86,18 +86,11 @@ constexpr auto
sm90_compute_tile_shape_or_override() {
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
constexpr int SmemAlloc = 4096;
if constexpr (detail::sm90_is_cooperative_v<Schedule>) {
constexpr int M = 128;
constexpr int N = SmemAlloc / (M * sizeof(Element));
return make_shape(Int<M>{}, Int<N>{});
return Shape<_128,_16>{};
}
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
constexpr int M = 64;
constexpr int N = SmemAlloc / (M * sizeof(Element));
return make_shape(Int<M>{}, Int<N>{});
return Shape<_64,_32>{};
}
else {
static_assert(cutlass::detail::dependent_false<Schedule>, "Unsupported schedule.");
@@ -167,8 +160,8 @@ template <
class EpilogueTileType,
class ElementAccumulator,
class ElementCompute,
class ElementC,
class GmemLayoutTagC,
class ElementC_,
class GmemLayoutTagC_,
int AlignmentC,
class ElementD,
class GmemLayoutTagD,
@@ -178,6 +171,11 @@ template <
class DispatchPolicy
>
struct TmaBuilderImpl {
// Passing void C disables source load
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
using GmemStrideTypeC = gemm::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = gemm::TagToStrideC_t<GmemLayoutTagD>;
@@ -188,7 +186,7 @@ struct TmaBuilderImpl {
DispatchPolicy,
TileShape_MNK,
EpilogueTile_MN,
ElementC,
ElementC_, // Need to pass void through to expose via GemmUniversal
GmemStrideTypeC,
ElementD,
GmemStrideTypeD,
@@ -246,8 +244,9 @@ struct CollectiveBuilder<
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
static constexpr int FragmentSize = 1;
using ThreadOp = thread::LinearCombination<
ElementD, 1, ElementAccumulator, ElementCompute,
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
ScaleType, FloatRoundStyle::round_to_nearest, ElementC>;
using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
@@ -267,7 +266,7 @@ template <
class ElementAccumulator,
class ElementCompute,
class ElementC_,
class GmemLayoutTagC_,
class GmemLayoutTagC,
int AlignmentC,
class ElementD,
class GmemLayoutTagD,
@@ -283,7 +282,7 @@ struct CollectiveBuilder<
ElementAccumulator,
ElementCompute,
ElementC_,
GmemLayoutTagC_,
GmemLayoutTagC,
AlignmentC,
ElementD,
GmemLayoutTagD,
@@ -292,43 +291,26 @@ struct CollectiveBuilder<
cute::enable_if_t<cute::is_same_v<Schedule, TmaWarpSpecialized> ||
cute::is_same_v<Schedule, TmaWarpSpecializedCooperative> >> {
public:
// Passing void C disables source load
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
ElementD, ElementC_>; // prevents cute breakages
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,
GmemLayoutTagD, GmemLayoutTagC_>;
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
static constexpr int FragmentSize = 4;
using ThreadOp = thread::LinearCombination<
ElementD, AlignmentD, ElementAccumulator, ElementCompute,
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
ScaleType, FloatRoundStyle::round_to_nearest, ElementC>;
using GmemStrideTypeC = gemm::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = gemm::TagToStrideC_t<GmemLayoutTagD>;
using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override<
ElementD, EpilogueTileType, Schedule>());
private:
static constexpr int StagesC = 1;
static constexpr int StagesD = 2;
static constexpr bool DisableReuseSmemC = true;
using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue<
cutlass::epilogue::Sm90TmaWarpSpecialized<StagesC,StagesD,DisableReuseSmemC>,
TileShape_MNK,
EpilogueTile_MN,
ElementC_, // need to pass void to expose via GemmUniversal
GmemStrideTypeC,
ElementD,
GmemStrideTypeD,
ThreadOp,
SM90_TMA_LOAD,
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, TileShape_MNK>()),
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
SM90_TMA_STORE,
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>())
>;
using Impl = detail::TmaBuilderImpl<
TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute,
ElementC_, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD,
Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecialized<StagesC,StagesD, DisableReuseSmemC>>;
public:
using CollectiveOp = typename Impl::CollectiveOp;
};
// Auto builder
@@ -427,11 +409,11 @@ struct CollectiveBuilder<
Schedule,
cute::enable_if_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule> ||
cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
public:
static constexpr int FragmentSize = 4;
using ThreadOp = thread::LinearCombinationGeneric<
Schedule::ActivationFunctor,
ElementD, AlignmentD,
ElementD, FragmentSize,
ElementAccumulator, ElementCompute, Schedule::Scale,
Schedule::Round>;
@@ -455,7 +437,7 @@ template <
class EpilogueTileType,
class ElementAccumulator,
class ElementCompute,
class ElementC,
class ElementC_,
class GmemLayoutTagC,
int AlignmentC,
class ElementD,
@@ -471,7 +453,7 @@ struct CollectiveBuilder<
EpilogueTileType,
ElementAccumulator,
ElementCompute,
ElementC,
ElementC_,
GmemLayoutTagC,
AlignmentC,
ElementD,
@@ -480,10 +462,14 @@ struct CollectiveBuilder<
Schedule,
cute::enable_if_t<cute::is_base_of_v<TmaWarpSpecializedBiasElementwiseBase, Schedule> ||
cute::is_base_of_v<TmaWarpSpecializedCooperativeBiasElementwiseBase, Schedule> >> {
private:
// Passing void C disables source load
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>, ElementD, ElementC_>; // prevents void ref breakages
public:
static constexpr int FragmentSize = 4;
using ThreadOp = thread::LinearCombinationBiasElementwise<
ElementC, ElementAccumulator, ElementCompute, ElementD, typename Schedule::ElementT, AlignmentD,
ElementC, ElementAccumulator, ElementCompute, ElementD, typename Schedule::ElementT, FragmentSize,
typename Schedule::ActivationFunctor<ElementCompute>, typename Schedule::BiasOp<ElementCompute>,
Schedule::StoreT, typename Schedule::ElementBias>;
@@ -492,7 +478,7 @@ private:
static constexpr int StagesD = 2;
using Impl = detail::TmaBuilderImpl<
TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute,
ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD,
ElementC_, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD,
Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise<StagesC,StagesD>>;
public:
@@ -540,8 +526,9 @@ struct CollectiveBuilder<
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
static constexpr int FragmentSize = 1;
using ThreadOp = thread::LinearCombination<
ElementD, 1, ElementAccumulator, ElementCompute,
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
ScaleType, FloatRoundStyle::round_to_nearest, ElementC>;
using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<


@@ -75,6 +75,9 @@ public:
using ElementD = typename ThreadEpilogueOp::ElementD;
using StrideD = StrideD_;
using GmemTiledCopyC = void;
using GmemTiledCopyD = void;
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;


@@ -48,8 +48,8 @@ template <
int StagesC_,
int StagesD_,
bool DisableSmemReuseC_,
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) per-collective
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N)
class ElementC_,
class StrideC_,
class ElementD_,
@@ -65,7 +65,7 @@ template <
class CollectiveEpilogue<
Sm90TmaWarpSpecialized<StagesC_,StagesD_,DisableSmemReuseC_>,
BlockTileShape_,
EpilogueTile_,
EpilogueTileShape_,
ElementC_,
StrideC_,
ElementD_,
@@ -84,7 +84,7 @@ public:
//
using DispatchPolicy = Sm90TmaWarpSpecialized<StagesC_,StagesD_,DisableSmemReuseC_>;
using BlockTileShape = BlockTileShape_;
using EpilogueTile = EpilogueTile_;
using EpilogueTileShape = EpilogueTileShape_;
using ThreadEpilogueOp = ThreadEpilogueOp_;
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
@@ -103,24 +103,27 @@ public:
using SmemLayoutAtomD = SmemLayoutAtomD_;
using CopyOpR2S = CopyOpR2S_;
using GmemTiledCopyC = SM90_TMA_LOAD;
using GmemTiledCopyD = SM90_TMA_STORE;
constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount;
constexpr static bool iskThreadEpilogueOpWithBias = detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value;
using AlignmentType = typename uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(sizeof(ElementD) == 2, "Only 16b output supported for now");
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
static_assert(!is_layout<EpilogueTileShape>::value && is_tuple<EpilogueTileShape>::value, "EpilogueTileShape must be a cute::Shape");
static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]");
static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
static_assert(rank(EpilogueTileShape{}) == 2, "EpilogueTileShape must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
private:
using InternalElementC = std::conditional_t<std::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
using InternalElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
constexpr static int StagesC = StagesC_;
constexpr static int StagesD = StagesD_;
constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default ||
ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::NoBetaScaling;
static_assert((std::is_void_v<ElementC> && not is_source_supported) || (not std::is_void_v<ElementC> && is_source_supported));
static_assert((cute::is_void_v<ElementC> && not is_source_supported) || (not cute::is_void_v<ElementC> && is_source_supported),
"Inconsistent C type and Scale kind");
// internal optimization to reuse C shared memory for storing D
using SmemLayoutAtomBitsC = decltype(downcast<sizeof_bits<InternalElementC>::value>(SmemLayoutAtomC{}));
@@ -131,21 +134,14 @@ private:
StrideC{} == StrideD{} &&
cute::is_same_v<SmemLayoutAtomBitsC,SmemLayoutAtomBitsD>;
// Find the max contiguous layout usable by TMA (if EpilogueTile is a by-mode tiler)
using SmemLayoutTmaD = decltype(tile_to_shape(
SmemLayoutAtomD{},
make_shape(max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))),
max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{})))),
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1>, Step<_1,_2>>{} ));
public:
using SmemLayoutC = decltype(tile_to_shape(
SmemLayoutAtomC{},
make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int<StagesC>{}),
cute::conditional_t<get<0>(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
using SmemLayoutD = decltype(tile_to_shape(
SmemLayoutTmaD{},
make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int<StagesD>{}),
SmemLayoutAtomD{},
make_shape(size<0>(EpilogueTileShape{}), size<1>(EpilogueTileShape{}), Int<StagesD>{}),
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
// TMA pipeline for loading C
@@ -194,7 +190,7 @@ public:
CopyOpS2G{},
make_tensor(static_cast<ElementD const*>(nullptr),
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
SmemLayoutTmaD{}));
SmemLayoutD{}(_,_,0)));
typename ThreadEpilogueOp::Params thread{};
TMA_C tma_load_c;
@@ -210,23 +206,32 @@ public:
to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
[[maybe_unused]] void* workspace)
{
[[maybe_unused]] void* workspace) {
// Optionally append _1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{});
auto M = get<0>(problem_shape_MNKL);
auto N = get<1>(problem_shape_MNKL);
auto L = get<3>(problem_shape_MNKL);
Tensor tensor_c = make_tensor(static_cast<InternalElementC const*>(args.ptr_C), make_layout(make_shape(M,N,L), args.dC));
typename Params::TMA_C tma_load_c = [&]() {
if constexpr (not cute::is_void_v<ElementC>) {
Tensor tensor_c = make_tensor(static_cast<InternalElementC const*>(args.ptr_C), make_layout(make_shape(M,N,L), args.dC));
return make_tma_copy(
CopyOpG2S{},
tensor_c,
SmemLayoutC{}(_,_,0));
}
else {
return typename Params::TMA_C{};
}
}();
Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD));
typename Params::TMA_C tma_load_c = make_tma_copy(
CopyOpG2S{},
tensor_c,
SmemLayoutC{}(_,_,0));
typename Params::TMA_D tma_store_d = make_tma_copy(
CopyOpS2G{},
tensor_d,
SmemLayoutTmaD{});
SmemLayoutD{}(_,_,0));
return {
args.thread,
tma_load_c,
@@ -378,8 +383,8 @@ public:
auto L = get<3>(problem_shape_mnkl);
auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{});
auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{});
auto epi_tile_m = size<0>(shape(EpilogueTile{}));
auto epi_tile_n = size<1>(shape(EpilogueTile{}));
auto epi_tile_m = size<0>(EpilogueTileShape{});
auto epi_tile_n = size<1>(EpilogueTileShape{});
// Represent the full output tensor
Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l)
@@ -396,11 +401,14 @@ public:
SmemLayoutD{});
// Tile thread(b)lock tensors by (E)pilogue output tile shape (bE)
Tensor bEsC = local_tile(sC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgD = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEsC = local_tile(sC, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgD = local_tile(gD, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Partition for register to smem copy (tRS_)
TiledCopy tiled_r2s = make_tiled_copy_C_atom(Copy_Atom<CopyOpR2S,ElementD>{}, tiled_mma);
using CopyAtomR2S = cute::conditional_t<cute::is_same_v<CopyOpR2S,DefaultCopy>,
Copy_Atom<UniversalCopy<uint_byte_t<sizeof(ElementD)*2>>,ElementD>,
Copy_Atom<CopyOpR2S,ElementD>>;
TiledCopy tiled_r2s = make_tiled_copy_C_atom(CopyAtomR2S{}, tiled_mma);
ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = conditional_return<ReuseSmemC>(
@@ -430,7 +438,7 @@ public:
thrblk_s2g.partition_S(bEsD) ); // (S2G,S2G_M,S2G_N,PIPE)
Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
CUTE_STATIC_ASSERT(size<0,0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
@@ -464,7 +472,13 @@ public:
int r2s_v = epi_n * size(tRS_rD_frg);
if (epilogue_op.is_source_needed()) {
// Copy source tile to register from smem
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
if constexpr (cute::is_same_v<CopyOpS2R,DefaultCopy>) {
copy(tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
}
else {
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tRS_rD_frg); ++i) {
tRS_rD_frg(i) = epilogue_op(tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i));
@@ -491,7 +505,12 @@ public:
}
// Copy output tile to smem from register
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
copy(tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
}
else {
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
}
}
else {
// Issue the TMA store of the previous iteration
@@ -514,7 +533,12 @@ public:
synchronize();
// Copy tile to smem from register
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
copy(tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
}
else {
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
}
// Advance pipeline state
store_pipe_producer_state_prev = store_pipe_producer_state;


@@ -47,8 +47,8 @@ namespace collective {
template <
int StagesC_,
int StagesD_,
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) per-collective
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N)
class ElementC_,
class StrideC_,
class ElementD_,
@@ -64,7 +64,7 @@ template <
class CollectiveEpilogue<
Sm90TmaWarpSpecializedBiasElementwise<StagesC_, StagesD_>,
BlockTileShape_,
EpilogueTile_,
EpilogueTileShape_,
ElementC_,
StrideC_,
ElementD_,
@@ -81,10 +81,9 @@ public:
//
// Type Aliases
//
// derived types of output thread level operator
using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise<StagesC_, StagesD_>;
using BlockTileShape = BlockTileShape_;
using EpilogueTile = EpilogueTile_;
using EpilogueTileShape = EpilogueTileShape_;
using ThreadEpilogueOp = ThreadEpilogueOp_;
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
@@ -106,6 +105,9 @@ public:
using SmemLayoutAtomD = SmemLayoutAtomD_;
using CopyOpR2S = CopyOpR2S_;
using GmemTiledCopyC = SM90_TMA_LOAD;
using GmemTiledCopyD = SM90_TMA_STORE;
constexpr static bool StoreT = ThreadEpilogueOp::kStoreT;
constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount;
static_assert(detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value,
@@ -113,26 +115,28 @@ public:
constexpr static bool iskThreadEpilogueOpWithBias = true;
using AlignmentType = typename uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(sizeof(ElementC) == 2, "Only 16b source supported for now");
static_assert(sizeof(ElementD) == 2, "Only 16b output supported for now");
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
static_assert(!is_layout<EpilogueTileShape>::value && is_tuple<EpilogueTileShape>::value, "EpilogueTileShape must be a cute::Shape");
static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]");
static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
static_assert(rank(EpilogueTileShape{}) == 2, "EpilogueTileShape must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
private:
using InternalElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
constexpr static int StagesC = StagesC_;
constexpr static int StagesD = StagesD_;
constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default ||
ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::NoBetaScaling;
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
static_assert((cute::is_void_v<ElementC> && not is_source_supported) || (not cute::is_void_v<ElementC> && is_source_supported),
"Inconsistent C type and Scale kind");
// Find the max contiguous layout usable by TMA (if EpilogueTile is a by-mode tiler)
using SmemLayoutTmaD = decltype(tile_to_shape(
SmemLayoutAtomD{},
make_shape(max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))),
max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{})))),
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1>, Step<_1,_2>>{} ));
// internal optimization to reuse C shared memory for storing D
using SmemLayoutAtomBitsC = decltype(downcast<sizeof_bits<InternalElementC>::value>(SmemLayoutAtomC{}));
using SmemLayoutAtomBitsD = decltype(downcast<sizeof_bits<ElementD>::value>(SmemLayoutAtomD{}));
constexpr static bool ReuseSmemC = is_source_supported &&
sizeof(InternalElementC) == sizeof(ElementD) &&
StrideC{} == StrideD{} &&
cute::is_same_v<SmemLayoutAtomBitsC,SmemLayoutAtomBitsD> &&
not StoreT;
public:
using SmemLayoutC = decltype(tile_to_shape(
@@ -140,29 +144,31 @@ public:
make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int<StagesC>{}),
cute::conditional_t<get<0>(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
using SmemLayoutD = decltype(tile_to_shape(
SmemLayoutTmaD{},
make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int<StagesD>{}),
SmemLayoutAtomD{},
make_shape(size<0>(EpilogueTileShape{}), size<1>(EpilogueTileShape{}), Int<StagesD>{}),
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
// TMA pipeline for loading C
using LoadPipeline = cutlass::PipelineTransactionAsync<is_source_supported ? StagesC : 0>;
using LoadPipelineState = cutlass::PipelineState<is_source_supported ? StagesC : 0>;
constexpr static uint32_t TmaTransactionBytes =
size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof(ElementC));
size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof(InternalElementC));
// TMA pipeline for storing D and T
using StorePipeline = cutlass::PipelineTmaStore<StagesD>;
using StorePipelineState = cutlass::PipelineState<StagesD>;
// TMA pipeline for storing D and T. ReuseSmemC cannot be set to true if StoreT is enabled.
using StorePipeline = cutlass::PipelineTmaStore<ReuseSmemC ? StagesC : StagesD>;
using StorePipelineState = cutlass::PipelineState<ReuseSmemC ? StagesC : StagesD>;
struct SharedStorage {
struct TensorStorage : aligned_struct<128> {
cute::conditional_t<not is_source_supported,
detail::EmptyStorage<ElementC>,
array_aligned<ElementC, size(SmemLayoutC{})>> smem_C;
alignas(128) array_aligned<ElementD, size(SmemLayoutD{})> smem_D;
alignas(128) cute::conditional_t<not StoreT,
detail::EmptyStorage<InternalElementC>,
array_aligned<InternalElementC, size(SmemLayoutC{})>> smem_C;
alignas(128) cute::conditional_t<ReuseSmemC,
detail::EmptyStorage<ElementD>,
array_aligned<ElementD, size(SmemLayoutD{})>> smem_D;
alignas(128) cute::conditional_t<not StoreT,
detail::EmptyStorage<ElementT>,
array_aligned<ElementT, size(SmemLayoutD{})>> smem_T;
} tensors;
using PipelineStorage = typename LoadPipeline::SharedStorage;
@@ -173,29 +179,32 @@ public:
// Host side epilogue arguments
struct Arguments {
typename ThreadEpilogueOp::Params thread{};
ElementC const* ptr_C = nullptr;
StrideC dC{};
ElementD* ptr_D = nullptr;
StrideD dD{};
typename ThreadEpilogueOp::Params thread;
ElementC const* ptr_C;
StrideC dC;
ElementD const* ptr_D;
StrideD dD;
ElementBias const* ptr_Bias = nullptr;
ElementT* ptr_T = nullptr;
ElementT const* ptr_T = nullptr;
};
// Device side epilogue params
struct Params {
using TMA_C = decltype(make_tma_copy(
CopyOpG2S{},
make_tensor(static_cast<ElementC const*>(nullptr), repeat_like(StrideC{}, int32_t(0)), StrideC{}),
make_tensor(static_cast<InternalElementC const*>(nullptr),
repeat_like(StrideC{}, int32_t(0)), StrideC{}),
SmemLayoutC{}(_,_,0)));
using TMA_D = decltype(make_tma_copy(
CopyOpS2G{},
make_tensor(static_cast<ElementD*>(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD_{}),
SmemLayoutTmaD{}));
make_tensor(static_cast<ElementD const*>(nullptr),
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
SmemLayoutD{}(_,_,0)));
using TMA_T = decltype(make_tma_copy(
CopyOpS2G{},
make_tensor(static_cast<ElementT*>(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD{}),
SmemLayoutTmaD{}));
make_tensor(static_cast<ElementT const*>(nullptr),
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
SmemLayoutD{}(_,_,0)));
typename ThreadEpilogueOp::Params thread{};
TMA_C tma_load_c;
TMA_D tma_store_d;
@@ -209,29 +218,42 @@ public:
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) {
to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
[[maybe_unused]] void* workspace) {
// Optionally append _1s until problem shape is rank-4, in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{});
auto M = get<0>(problem_shape_MNKL);
auto N = get<1>(problem_shape_MNKL);
auto L = get<3>(problem_shape_MNKL);
Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC));
typename Params::TMA_C tma_load_c = [&]() {
if constexpr (not cute::is_void_v<ElementC>) {
Tensor tensor_c = make_tensor(static_cast<InternalElementC const*>(args.ptr_C), make_layout(make_shape(M,N,L), args.dC));
return make_tma_copy(
CopyOpG2S{},
tensor_c,
SmemLayoutC{}(_,_,0));
}
else {
return typename Params::TMA_C{};
}
}();
Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD));
typename Params::TMA_C tma_load_c = make_tma_copy(
CopyOpG2S{},
tensor_c,
SmemLayoutC{}(_,_,0));
typename Params::TMA_D tma_store_d = make_tma_copy(
CopyOpS2G{},
tensor_d,
SmemLayoutTmaD{});
SmemLayoutD{}(_,_,0));
typename Params::TMA_T tma_store_t = [&]() {
if constexpr (StoreT) {
Tensor tensor_t = make_tensor(args.ptr_T, make_layout(make_shape(M,N,L), args.dD));
return make_tma_copy(
CopyOpS2G{},
tensor_t,
SmemLayoutTmaD{});
SmemLayoutD{}(_,_,0));
}
else {
return typename Params::TMA_T{};
@@ -262,6 +284,10 @@ public:
CUTLASS_HOST_DEVICE
static constexpr int
get_store_pipe_increment(TileShapeMNK tile_shape_MNK) {
if constexpr (ReuseSmemC) {
return get_load_pipe_increment(tile_shape_MNK);
}
// Compute number of D subtiles
constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutD{});
constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutD{});
@@ -276,7 +302,7 @@ public:
CUTLASS_DEVICE
bool
is_source_needed() {
return epilogue_op.is_source_needed();
return is_source_supported && epilogue_op.is_source_needed();
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -390,8 +416,8 @@ public:
auto L = get<3>(problem_shape_mnkl);
auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{});
auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{});
auto epi_tile_m = size<0>(shape(EpilogueTile{}));
auto epi_tile_n = size<1>(shape(EpilogueTile{}));
auto epi_tile_m = size<0>(EpilogueTileShape{});
auto epi_tile_n = size<1>(EpilogueTileShape{});
// Represent the full output tensor
Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l)
@@ -407,7 +433,7 @@ public:
Tensor gT = gT_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N)
Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N)
// Construct the smem tensors for source (sC) and output (sD)
// Construct the smem tensors for source (sC) and output (sD, sT)
Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), // (TILE_M,TILE_N)
SmemLayoutC{})(_,_,load_pipe_consumer_state.index());
Tensor bEsD = make_tensor(make_smem_ptr(shared_tensors.smem_D.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE)
@@ -416,21 +442,26 @@ public:
SmemLayoutD{});
// Tile thread(b)lock tensors by (E)pilogue output tile shape (bE)
Tensor bEsC = local_tile(sC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgD = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgT = local_tile(gT, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgBias = local_tile(gBias, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEsC = local_tile(sC, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgD = local_tile(gD, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgT = local_tile(gT, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor bEgBias = local_tile(gBias, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Partition for register to smem copy (tRS_)
TiledCopy tiled_r2s = make_tiled_copy_C_atom(Copy_Atom<CopyOpR2S,ElementD>{}, tiled_mma);
using CopyAtomR2S = cute::conditional_t<cute::is_same_v<CopyOpR2S,DefaultCopy>,
Copy_Atom<UniversalCopy<uint_byte_t<sizeof(ElementD)*2>>,ElementD>,
Copy_Atom<CopyOpR2S,ElementD>>;
TiledCopy tiled_r2s = make_tiled_copy_C_atom(CopyAtomR2S{}, tiled_mma);
ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(bEsD); // (R2S,R2S_M,R2S_N,PIPE)
Tensor tRS_sD = conditional_return<ReuseSmemC>(
thread_r2s.partition_D(recast<ElementD>(bEsC)), // (R2S,R2S_M,R2S_N,EPI_M,EPI_N)
thread_r2s.partition_D(bEsD) ); // (R2S,R2S_M,R2S_N,PIPE)
Tensor tRS_sT = thread_r2s.partition_D(bEsT); // (R2S,R2S_M,R2S_N,PIPE)
// Allocate register tensors
auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(bEsD))); // (R2S,R2S_M,R2S_N)
Tensor tRS_rC = make_tensor<ElementC>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
Tensor tRS_rC = make_tensor<InternalElementC>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
Tensor tRS_rD = make_tensor<ElementD>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
Tensor tRS_rT = make_tensor<ElementT>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
@@ -445,21 +476,23 @@ public:
Tensor tRS_rBias_frg = recast<typename ThreadEpilogueOp::FragmentBias>(tRS_rBias);
// Partition for smem to register copy (tSR_)
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R,ElementC>{}, tiled_r2s);
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R,InternalElementC>{}, tiled_r2s);
ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx);
Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N)
Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N)
// Partition for smem to gmem copy (tSG_)
ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{});
Tensor tSG_sD = thrblk_s2g.partition_S(bEsD); // (S2G,S2G_M,S2G_N,PIPE)
Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
Tensor tSG_sD = conditional_return<ReuseSmemC>(
thrblk_s2g.partition_S(recast<ElementD>(bEsC)), // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
thrblk_s2g.partition_S(bEsD) ); // (S2G,S2G_M,S2G_N,PIPE)
Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
ThrCopy thrblk_s2g_t = params.tma_store_t.get_slice(Int<0>{});
Tensor tSG_sT = thrblk_s2g_t.partition_S(bEsT); // (S2G,S2G_M,S2G_N,PIPE)
Tensor tSG_gT = thrblk_s2g_t.partition_D(bEgT); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
CUTE_STATIC_ASSERT(size<0,0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
@@ -470,11 +503,15 @@ public:
// Predication for TMA store (one warp issues TMA store)
bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0;
if (epilogue_op.is_source_needed()) {
if (is_source_supported && epilogue_op.is_source_needed()) {
// Wait for epilogue load to fill smem buffer with C
load_pipeline.consumer_wait(load_pipe_consumer_state);
}
// Delay issue of TMA store by 1 iteration to achieve better instruction pipelining
PipelineState store_pipe_producer_state_prev = store_pipe_producer_state;
int epi_m_prev = 0, epi_n_prev = 0;
// For each output tile
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(bEgD); ++epi_n) {
@@ -490,9 +527,14 @@ public:
// Elementwise operation with conversion
int r2s_v = epi_n * size(tRS_rD_frg);
if (epilogue_op.is_source_needed()) {
if (is_source_supported && epilogue_op.is_source_needed()) {
// Copy source tile to registers from smem
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
if constexpr (cute::is_same_v<CopyOpS2R,DefaultCopy>) {
copy(tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
}
else {
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tRS_rD_frg); ++i) {
@@ -506,40 +548,119 @@ public:
}
}
// Wait for a smem buffer to be available
if (issue_tma_store) {
store_pipeline.producer_acquire(store_pipe_producer_state);
}
synchronize();
if constexpr (ReuseSmemC) {
// If ReuseSmemC is true, StoreT must be false. Therefore, we do not perform copies for T in this block.
// Copy tile to smem from register
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
// Issue the TMA store of the previous iteration
if (not (epi_m == 0 && epi_n == 0)) {
// Make sure smem writes are visible to TMA
cutlass::arch::fence_view_async_shared();
synchronize(); // ensure all threads have issued their async fence
if constexpr (StoreT) {
copy(tiled_r2s, tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index()));
}
// Make sure smem writes are visible to TMA
cutlass::arch::fence_view_async_shared();
synchronize(); // ensure all threads have issued their async fence
// Write the tile to gmem from smem with TMA
if (issue_tma_store) {
copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state.index()), tSG_gD(_,_,_,epi_m,epi_n));
if constexpr (StoreT) {
copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state.index()), tSG_gT(_,_,_,epi_m,epi_n));
// Write the tile to gmem from smem with TMA
if (issue_tma_store) {
copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
}
}
store_pipeline.producer_commit(store_pipe_producer_state);
// Copy output tile to smem from register
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
copy(tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
}
else {
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
}
}
else {
// Issue the TMA store of the previous iteration
if (not (epi_m == 0 && epi_n == 0)) {
// Make sure smem writes are visible to TMA
cutlass::arch::fence_view_async_shared();
synchronize(); // ensure all threads have issued their async fence
// Write the tile to gmem from smem with TMA
if (issue_tma_store) {
copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
if constexpr (StoreT) {
copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state_prev.index()), tSG_gT(_,_,_,epi_m_prev,epi_n_prev));
}
store_pipeline.producer_commit(store_pipe_producer_state_prev);
}
}
// Wait for a smem buffer to be available
if (issue_tma_store) {
store_pipeline.producer_acquire(store_pipe_producer_state);
}
synchronize();
// Copy tile to smem from register
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
copy(tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
if constexpr (StoreT) {
copy(tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index()));
}
}
else {
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
if constexpr (StoreT) {
copy(tiled_r2s, tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index()));
}
}
// Advance pipeline state
store_pipe_producer_state_prev = store_pipe_producer_state;
++store_pipe_producer_state;
}
// Advance pipeline state
++store_pipe_producer_state;
epi_m_prev = epi_m;
epi_n_prev = epi_n;
}
}
// Let dma warp know smem buffer is consumed and empty
if (epilogue_op.is_source_needed()) {
load_pipeline.consumer_release(load_pipe_consumer_state);
if constexpr (ReuseSmemC) {
// If ReuseSmemC is true, StoreT must be false. Therefore, we do not perform copies for T in this block.
// Fence and issue the TMA store of the last iteration
cutlass::arch::fence_view_async_shared();
synchronize(); // ensure all threads have issued their async fence
if (issue_tma_store) {
copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
}
// Arrive and advance pipeline state
if (issue_tma_store) {
store_pipeline.producer_commit(store_pipe_producer_state);
}
++store_pipe_producer_state;
// Wait for a smem buffer to be available
if (issue_tma_store) {
store_pipeline.producer_acquire(store_pipe_producer_state);
}
synchronize();
// Let dma warp know smem buffer is consumed and empty
if (is_source_supported && epilogue_op.is_source_needed()) {
load_pipeline.consumer_release(store_pipe_producer_state);
}
}
else {
// Fence and issue the TMA store of the last iteration
cutlass::arch::fence_view_async_shared();
synchronize(); // ensure all threads have issued their async fence
if (issue_tma_store) {
copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
if constexpr (StoreT) {
copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state_prev.index()), tSG_gT(_,_,_,epi_m_prev,epi_n_prev));
}
store_pipeline.producer_commit(store_pipe_producer_state_prev);
}
// Let dma warp know smem buffer is consumed and empty
if (epilogue_op.is_source_needed()) {
load_pipeline.consumer_release(load_pipe_consumer_state);
}
}
}
View File
@@ -87,7 +87,7 @@ public:
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementOutput, kElementsPerAccess>;
using FragmentC = Array<ElementC, kElementsPerAccess>;
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
using FragmentT = Array<ElementT, kElementsPerAccess>;
View File
@@ -28,9 +28,9 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
@@ -57,7 +57,7 @@ namespace cutlass::gemm::device {
////////////////////////////////////////////////////////////////////////////////
/*!
GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel
of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal.
@@ -159,10 +159,10 @@ public:
typename CollectiveMainloop::GmemTiledCopyA, ElementA>();
static int constexpr kAlignmentB = gemm::detail::get_alignment_count_from_gmem_tiled_copy<
typename CollectiveMainloop::GmemTiledCopyB, ElementB>();
// NOTE: 3.0 DefaultEpilogues don't support vectorized stores (yet)
static int constexpr kAlignmentC = 1;
static int constexpr kAlignmentD = 1;
static int constexpr kAlignmentC = gemm::detail::get_alignment_count_from_gmem_tiled_copy<
typename CollectiveEpilogue::GmemTiledCopyC, ElementC>();
static int constexpr kAlignmentD = gemm::detail::get_alignment_count_from_gmem_tiled_copy<
typename CollectiveEpilogue::GmemTiledCopyD, ElementD>();
using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp;
@@ -327,7 +327,7 @@ public:
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversal::run()");
dim3 constexpr block = GemmKernel::get_block_shape();
dim3 const block = GemmKernel::get_block_shape();
dim3 const grid = get_grid_shape(params);
// configure smem size and carveout
@@ -404,19 +404,19 @@ public:
using GemmKernel = GemmKernel_;
static bool const kInternalTranspose =
cute::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using WarpShape = typename GemmKernel::WarpShape;
using InstructionShape = typename GemmKernel::InstructionShape;
// warp-level, arch-level (instruction), math operator
using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
// Operator class and arch tag extract bottom-up
// set it for top-level gemm device-level template
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
@@ -444,15 +444,15 @@ public:
using LayoutB = typename MapArguments::LayoutB;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
static int const kAlignmentB = MapArguments::kAlignmentB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename MapArguments::LayoutC;
static int const kAlignmentC = GemmKernel::kAlignmentC;
// C and D same type for 2.x kernel
using ElementD = ElementC;
using LayoutD = LayoutC;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
@@ -493,12 +493,12 @@ public:
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
}
@@ -532,12 +532,12 @@ public:
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
View File
@@ -75,6 +75,8 @@ public:
static ComplexTransform const kTransformB = GemvKernel::kTransformB;
static int const kThreadCount = GemvKernel::kThreadCount;
static int const kThreadsPerRow = GemvKernel::kThreadsPerRow;
static int const kStages = GemvKernel::kStages;
static int const kAlignmentA = GemvKernel::kAlignmentA;
@@ -106,8 +108,23 @@ public:
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return dim3((args.problem_size.row() + (kThreadCount - 1)) / kThreadCount, 1, args.batch_count % 65565);
static dim3 get_grid_shape(Arguments const &args, dim3 const &block) {
if(platform::is_same<LayoutA, layout::ColumnMajor>::value) {
return dim3((args.problem_size.row() + (block.x - 1)) / block.x, 1, args.batch_count % 65536);
}
else {
return dim3((args.problem_size.row() + (block.y - 1)) / block.y, 1, args.batch_count % 65536);
}
}
/// Computes the block shape
static dim3 get_block_shape() {
if(platform::is_same<LayoutA, layout::ColumnMajor>::value) {
return dim3(kThreadCount, 1, 1);
}
else {
return dim3(kThreadsPerRow, kThreadCount / kThreadsPerRow, 1);
}
}
/// Initializes Gemv state from arguments.
@@ -124,8 +141,8 @@ public:
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
dim3 grid = get_grid_shape(params_);
dim3 block(GemvKernel::kThreadCount, 1, 1);
dim3 block = get_block_shape();
dim3 grid = get_grid_shape(params_, block);
int smem_size = int(sizeof(typename GemvKernel::SharedStorage));
@@ -137,11 +154,7 @@ public:
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
return Status::kSuccess;
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
View File
@@ -1,167 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemvKernel_>
class GemvStridedBatched {
public:
using GemvKernel = GemvKernel_;
using ElementA = typename GemvKernel::ElementA;
using LayoutA = typename GemvKernel::LayoutA;
using ElementB = typename GemvKernel::ElementB;
using ElementC = typename GemvKernel::ElementC;
using ElementAccumulator = typename GemvKernel::ElementAccumulator;
using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp;
static ComplexTransform const kTransformA = GemvKernel::kTransformA;
static ComplexTransform const kTransformB = GemvKernel::kTransformB;
static int const kThreadCount = GemvKernel::kThreadCount;
static int const mThreadCount = GemvKernel::mThreadCount;
static int const kStages = GemvKernel::kStages;
static int const kAlignmentA = GemvKernel::kAlignmentA;
static int const kAlignmentB = GemvKernel::kAlignmentB;
static int const kAlignmentC = GemvKernel::kAlignmentC;
using Arguments = typename GemvKernel::Arguments;
using Params = typename GemvKernel::Params;
private:
Params params_;
public:
/// Constructs the Gemv.
GemvStridedBatched() {}
/// Determines whether the Gemv can execute the given problem.
static Status can_implement(Arguments const& args) {
return GemvKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args) { return 0; }
/// Initializes Gemv state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
params_ = Params(args);
if (args.problem_size.column() % GemvKernel::kElementsPerAccess) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return params_.update(args);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
dim3 grid(1, 1, params_.batch_count % 65536);
dim3 block(kThreadCount, mThreadCount, 1);
int smem_size = 0;
// Launch
cutlass::Kernel<GemvKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) { return run(stream); }
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
View File
@@ -522,15 +522,27 @@ stride_to_layout_tag_B() {
template <class GmemTiledCopy, class Element>
constexpr int
get_alignment_count_from_gmem_tiled_copy() {
// For TMA tiled copies, we know the alignment has to be 128 bits
if constexpr ( cute::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy>
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>
) {
return 128 / sizeof_bits<Element>::value;
if constexpr (cute::is_void_v<GmemTiledCopy>) {
return 1;
}
// Account for ElementC = void kernels
else if constexpr (cute::is_void_v<Element>) {
return 0;
}
else {
// For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN
return GmemTiledCopy::NumValSrc;
// For TMA tiled copies, we know the alignment has to be 128 bits
if constexpr ( cute::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy>
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>
|| cute::is_base_of_v<cute::SM90_TMA_STORE, GmemTiledCopy>
) {
return 128 / sizeof_bits<Element>::value;
}
else {
// For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN
return GmemTiledCopy::NumValSrc;
}
}
}
View File
@@ -41,9 +41,13 @@
#include "cutlass/complex.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@@ -58,18 +62,49 @@ template <
typename ElementB_,
typename ElementC_,
typename ElementAccumulator_,
typename EpilogueOutputOp_
typename EpilogueOutputOp_,
int kElementsPerAccess_ = 1, ///< Number of elements involved in a global access.
int kThreadCount_ = 0, ///< Number of threads in the thread block.
/// It will be calculated automatically if set to 0.
int kThreadsPerRow_ = 0 ///< Number of threads in the k dimension.
/// It will be calculated automatically if set to 0.
>
struct Gemv {
struct Gemv;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
// GEMV for column-major A matrix
template <
typename ElementA_,
typename ElementB_,
typename ElementC_,
typename ElementAccumulator_,
typename EpilogueOutputOp_,
int kElementsPerAccess_,
int kThreadCount_,
int kThreadsPerRow_
>
struct Gemv <
ElementA_,
layout::ColumnMajor,
ElementB_,
ElementC_,
ElementAccumulator_,
EpilogueOutputOp_,
kElementsPerAccess_,
kThreadCount_,
kThreadsPerRow_
>{
public:
using ElementA = ElementA_;
using LayoutA = layout::ColumnMajor;
using TensorRefA = TensorRef<ElementA, LayoutA>;
static_assert(platform::is_same<LayoutA, LayoutA_>::value,
"Only supported for column-major A matrix");
using ElementB = ElementB_;
using ElementC = ElementC_;
@@ -79,7 +114,10 @@ public:
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static int const kThreadCount = 32;
// thread block shape (kThreadCount, 1, 1)
static int const kThreadCount = (kThreadCount_ == 0) ? 32 : kThreadCount_;
static int const kThreadsPerRow = kThreadsPerRow_;
static int const kStages = 1;
static int const kAlignmentA = 1;
@@ -121,17 +159,17 @@ public:
MatrixCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params output_op,
TensorRefA ref_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t inc_B,
int64_t inc_C,
int64_t inc_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D
TensorRefA ref_A,
void const *ptr_B,
void const *ptr_C,
void *ptr_D,
int64_t inc_B,
int64_t inc_C,
int64_t inc_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D
):
problem_size(problem_size),
batch_count(batch_count),
@@ -151,14 +189,44 @@ public:
Arguments(
MatrixCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params output_op,
TensorRefA ref_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t inc_B,
int64_t inc_C,
int64_t inc_D
TensorRefA ref_A,
void const *ptr_B,
void const *ptr_C,
void *ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D
):
Arguments(
problem_size,
batch_count,
output_op,
ref_A,
ptr_B,
ptr_C,
ptr_D,
1,
1,
1,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D)
{ }
Arguments(
MatrixCoord problem_size,
typename EpilogueOutputOp::Params output_op,
TensorRefA ref_A,
void const *ptr_B,
void const *ptr_C,
void *ptr_D,
int64_t inc_B,
int64_t inc_C,
int64_t inc_D
):
Arguments(
problem_size,
@@ -206,7 +274,6 @@ public:
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::MatrixCoord const & problem_size) {
return Status::kSuccess;
}
@@ -214,7 +281,7 @@ public:
return can_implement(args.problem_size);
}
/// Executes one GEMM
/// Executes one GEMV
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
@@ -282,6 +349,288 @@ public:
/////////////////////////////////////////////////////////////////////////////////////////////////
// GEMV for row-major A matrix
template <
typename ElementA_,
typename ElementB_,
typename ElementC_,
typename ElementAccumulator_,
typename EpilogueOutputOp_,
int kElementsPerAccess_,
int kThreadCount_,
int kThreadsPerRow_
>
struct Gemv <
ElementA_,
layout::RowMajor,
ElementB_,
ElementC_,
ElementAccumulator_,
EpilogueOutputOp_,
kElementsPerAccess_,
kThreadCount_,
kThreadsPerRow_
>{
public:
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = ElementB_;
using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using EpilogueOutputOp = EpilogueOutputOp_;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static FloatRoundStyle const Round = cutlass::FloatRoundStyle::round_to_nearest;
// number of elements fetched per global memory access
static int const kElementsPerAccess = kElementsPerAccess_;
using FragmentA = Array<ElementA, kElementsPerAccess>;
using FragmentB = Array<ElementB, kElementsPerAccess>;
using FragmentCompute = Array<ElementAccumulator, kElementsPerAccess>;
// thread block shape (kThreadsPerRow, kThreadCount / kThreadsPerRow, 1)
static int const kThreadCount = (kThreadCount_ == 0) ? 128 : kThreadCount_;
static int const kThreadsPerRow = (kThreadsPerRow_ == 0) ?
std::min(static_cast<int>(kThreadCount / (kElementsPerAccess * sizeof(ElementA))), 16)
: kThreadsPerRow_;
//
// Structures
//
/// Argument structure
struct Arguments {
MatrixCoord problem_size;
int32_t batch_count;
typename EpilogueOutputOp::Params output_op;
TensorRefA ref_A;
ElementB const *ptr_B;
ElementC const *ptr_C;
ElementC *ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments(): batch_count(0) { }
Arguments(
MatrixCoord problem_size,
int32_t batch_count,
typename EpilogueOutputOp::Params output_op,
TensorRefA ref_A,
void const *ptr_B,
void const *ptr_C,
void *ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D
):
problem_size(problem_size),
batch_count(batch_count),
output_op(output_op),
ref_A(ref_A),
ptr_B(static_cast<ElementB const *>(ptr_B)),
ptr_C(static_cast<ElementC const *>(ptr_C)),
ptr_D(static_cast<ElementC *>(ptr_D)),
batch_stride_A(batch_stride_A),
batch_stride_B(batch_stride_B),
batch_stride_C(batch_stride_C),
batch_stride_D(batch_stride_D)
{ }
Arguments(
MatrixCoord problem_size,
typename EpilogueOutputOp::Params output_op,
TensorRefA ref_A,
void const *ptr_B,
void const *ptr_C,
void *ptr_D
):
Arguments(
problem_size,
1,
output_op,
ref_A,
ptr_B,
ptr_C,
ptr_D,
1,
1,
1,
1)
{ }
Status update(Arguments const &args) {
problem_size = args.problem_size;
batch_count = args.batch_count;
output_op = args.output_op;
ref_A = args.ref_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
return Status::kSuccess;
}
};
using Params = Arguments;
/// Shared memory storage structure
union SharedStorage {
};
public:
//
// Methods
//
CUTLASS_DEVICE
Gemv() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::MatrixCoord const &problem_size) {
if (problem_size.column() % kElementsPerAccess != 0) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
/// Executes one GEMV
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Loop over batch indices
for (int batch_idx = blockIdx.z; batch_idx < params.batch_count; batch_idx += gridDim.z) {
int idx_col_k = threadIdx.x;
int idx_row_m = blockIdx.x * blockDim.y + threadIdx.y;
if (idx_row_m < params.problem_size.row()) {
// problem_size (row = m, column = k)
// matrix A (batch, m, k)
// vector B (batch, 1, k)
// vector C (batch, m, 1)
// vector D (batch, m, 1)
// move in the batch dimension
ElementA const *ptr_A = params.ref_A.data() + batch_idx * params.batch_stride_A;
ElementB const *ptr_B = params.ptr_B + batch_idx * params.batch_stride_B;
ElementC const *ptr_C = params.ptr_C + batch_idx * params.batch_stride_C;
ElementC *ptr_D = params.ptr_D + batch_idx * params.batch_stride_D;
// move in the k dimension
ptr_A += idx_col_k * kElementsPerAccess;
ptr_B += idx_col_k * kElementsPerAccess;
// move in the m dimension
ptr_A += idx_row_m * params.problem_size.column();
ptr_C += idx_row_m;
ptr_D += idx_row_m;
NumericArrayConverter<ElementAccumulator, ElementA, kElementsPerAccess, Round> srcA_converter;
NumericArrayConverter<ElementAccumulator, ElementB, kElementsPerAccess, Round> srcB_converter;
ElementAccumulator accum = 0.f;
FragmentB fragB;
FragmentA fragA;
int unroll_col_k = 0;
// width of the rolling tile along K: columns consumed per main-loop iteration
int const tileA_k = kThreadsPerRow * kElementsPerAccess;
for (; unroll_col_k < params.problem_size.column() / tileA_k * tileA_k; unroll_col_k += tileA_k) {
// fetch from matrix A
arch::global_load<FragmentA,
sizeof(FragmentA),
arch::CacheOperation::LastUse>(fragA, (ptr_A + unroll_col_k), true);
// fetch from vector B
arch::global_load<FragmentB,
sizeof(FragmentB),
arch::CacheOperation::Always>(fragB, (ptr_B + unroll_col_k), true);
FragmentCompute fragB_Compute = srcB_converter(fragB);
FragmentCompute fragA_Compute = srcA_converter(fragA);
// Math
CUTLASS_PRAGMA_UNROLL
for (int e = 0; e < kElementsPerAccess; e++) {
accum += fragA_Compute.at(e) * fragB_Compute.at(e);
}
}
// process the remaining K elements;
// each thread fetches one element at a time
for (int k = unroll_col_k + idx_col_k; k < params.problem_size.column(); k += kThreadsPerRow) {
ElementB b = *(ptr_B - idx_col_k * kElementsPerAccess + k);
ElementA a = *(ptr_A - idx_col_k * kElementsPerAccess + k);
accum += ElementAccumulator(a) * ElementAccumulator(b);
}
EpilogueOutputOp output_op(params.output_op);
typename EpilogueOutputOp::FragmentOutput source_fragment;
// prefetch from the source vector C
if (output_op.is_source_needed()) {
source_fragment[0] = *(ptr_C);
}
typename EpilogueOutputOp::FragmentAccumulator accum_fragment;
typename EpilogueOutputOp::FragmentOutput output_fragment;
for (int mask = (kThreadsPerRow >> 1); mask > 0; mask >>= 1) {
accum += __shfl_xor_sync(0xFFFFFFFF, accum, mask, 32);
}
if (idx_col_k == 0) {
accum_fragment[0] = accum;
if (output_op.is_source_needed()) {
output_fragment = output_op(accum_fragment, source_fragment);
}
else {
output_fragment = output_op(accum_fragment);
}
*ptr_D = output_fragment[0];
}
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -1,368 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename ElementA_, /// matrix
typename LayoutA_,
typename ElementB_, /// vector
typename ElementC_,
typename ElementAccumulator_,
int kElementsPerAccess_,
typename EpilogueOutputOp_
>
struct GemvStridedBatched {
public:
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using TensorRefA = TensorRef<ElementA, LayoutA>;
static_assert(std::is_same<LayoutA, LayoutA_>::value,
"Only supported for row-major A matrix");
using ElementB = ElementB_;
using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using EpilogueOutputOp = EpilogueOutputOp_;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static FloatRoundStyle const Round = cutlass::FloatRoundStyle::round_to_nearest;
// number of return elements in a global access
static int const kElementsPerAccess = kElementsPerAccess_;
using FragmentA = Array<ElementA, kElementsPerAccess>;
using FragmentB = Array<ElementB, kElementsPerAccess>;
using FragmentCompute = Array<ElementAccumulator, kElementsPerAccess>;
// thread block shape (kThreadCount, mThreadCount)
static int const kThreadCount = std::min(static_cast<int>(128 / (kElementsPerAccess * sizeof(ElementA))), 16);
static int const mThreadCount = 128 / kThreadCount;
// rolling tile shape
static int const kTileA = kThreadCount * kElementsPerAccess;
static int const mTileA = mThreadCount * 8;
//
// Structures
//
/// Argument structure
struct Arguments
{
MatrixCoord problem_size;
int32_t batch_count;
typename EpilogueOutputOp::Params output_op;
TensorRefA ref_A;
ElementB const *ptr_B;
ElementC const *ptr_C;
ElementC *ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments() : batch_count(0) {}
Arguments(
MatrixCoord problem_size,
int32_t batch_count,
typename EpilogueOutputOp::Params output_op,
TensorRefA ref_A,
void const *ptr_B,
void const *ptr_C,
void *ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D) : problem_size(problem_size),
batch_count(batch_count),
output_op(output_op),
ref_A(ref_A),
ptr_B(static_cast<ElementB const *>(ptr_B)),
ptr_C(static_cast<ElementC const *>(ptr_C)),
ptr_D(static_cast<ElementC *>(ptr_D)),
batch_stride_A(batch_stride_A),
batch_stride_B(batch_stride_B),
batch_stride_C(batch_stride_C),
batch_stride_D(batch_stride_D)
{
}
Arguments(
MatrixCoord problem_size,
typename EpilogueOutputOp::Params output_op,
TensorRefA ref_A,
void const *ptr_B,
void const *ptr_C,
void *ptr_D) : Arguments(problem_size,
1,
1,
output_op,
ref_A,
ptr_B,
ptr_C,
ptr_D,
1,
1,
1,
1)
{
}
Status update(Arguments const &args)
{
problem_size = args.problem_size;
batch_count = args.batch_count;
output_op = args.output_op;
ref_A = args.ref_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
return Status::kSuccess;
}
};
using Params = Arguments;
/// Shared memory storage structure
union SharedStorage
{
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemvStridedBatched() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::MatrixCoord const &problem_size)
{
if (problem_size.column() % kElementsPerAccess != 0)
return Status::kErrorMisalignedOperand;
return Status::kSuccess;
}
static Status can_implement(Arguments const &args)
{
return can_implement(args.problem_size);
}
/// Executes one GEMV
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage)
{
// Loop over batch indices
for (int batch_idx = blockIdx.z; batch_idx < params.batch_count; batch_idx += gridDim.z)
{
int k_col_id = threadIdx.x;
int m_row_id = threadIdx.y;
// problem_size (row = m, column = k)
// matrix A (batch, m, k)
// vector B (batch, 1, k)
// vector C (batch, m, 1)
// vector D (batch, m, 1)
// move in the batch dimension
ElementA const *ptr_A = params.ref_A.data() + batch_idx * params.batch_stride_A;
ElementB const *ptr_B = params.ptr_B + batch_idx * params.batch_stride_B;
ElementC const *ptr_C = params.ptr_C + batch_idx * params.batch_stride_C;
ElementC *ptr_D = params.ptr_D + batch_idx * params.batch_stride_D;
// move in the k dimension
ptr_A += k_col_id * kElementsPerAccess;
ptr_B += k_col_id * kElementsPerAccess;
// move in the m dimension
ptr_A += m_row_id * params.problem_size.column();
ptr_C += m_row_id;
ptr_D += m_row_id;
NumericArrayConverter<ElementAccumulator, ElementA, kElementsPerAccess, Round> srcA_converter;
NumericArrayConverter<ElementAccumulator, ElementB, kElementsPerAccess, Round> srcB_converter;
for (; m_row_id < params.problem_size.row(); m_row_id += mTileA)
{
ElementAccumulator accum[mTileA / mThreadCount] = {0.f};
FragmentB fragB;
FragmentA fragA[mTileA / mThreadCount];
int mElemCountPerTile = min(mTileA / mThreadCount, (params.problem_size.row() - m_row_id - 1) / mThreadCount + 1);
int kUnroll = 0;
for (; kUnroll < params.problem_size.column() / kTileA * kTileA; kUnroll += kTileA)
{
for (int m = 0; m < mElemCountPerTile; m++)
{
// fetch from matrix A
arch::global_load<FragmentA,
sizeof(FragmentA),
arch::CacheOperation::LastUse>(fragA[m], (ptr_A + kUnroll + m * mThreadCount * params.problem_size.column()), true);
}
// fetch from vector B
arch::global_load<FragmentB,
sizeof(FragmentB),
arch::CacheOperation::Always>(fragB, (ptr_B + kUnroll), true);
for (int m = 0; m < mElemCountPerTile; m++)
{
FragmentCompute fragB_Compute = srcB_converter(fragB);
FragmentCompute fragA_Compute = srcA_converter(fragA[m]);
// Math
CUTLASS_PRAGMA_UNROLL
for (int e = 0; e < kElementsPerAccess; e++)
{
accum[m] += fragA_Compute.at(e) * fragB_Compute.at(e);
}
}
}
// calculate the rest of K elements
// each thread fetch 1 element each time
for (int k = kUnroll + k_col_id; k < params.problem_size.column(); k += kThreadCount)
{
ElementB b = *(ptr_B - k_col_id * kElementsPerAccess + k);
for (int m = 0; m < mElemCountPerTile; m++)
{
ElementA a = *(ptr_A - k_col_id * kElementsPerAccess + k + m * mThreadCount * params.problem_size.column());
accum[m] += ElementAccumulator(a) * ElementAccumulator(b);
}
}
EpilogueOutputOp output_op(params.output_op);
typename EpilogueOutputOp::FragmentOutput source_fragment[mTileA / mThreadCount];
// prefetch from source matrix C
if (output_op.is_source_needed())
{
for (int m = 0; m < mElemCountPerTile; m++)
{
source_fragment[m][0] = *(ptr_C + m * mThreadCount);
}
}
typename EpilogueOutputOp::FragmentAccumulator accum_fragment;
typename EpilogueOutputOp::FragmentOutput output_fragment;
for (int m = 0; m < mElemCountPerTile; m++)
{
for (int mask = (kThreadCount >> 1); mask > 0; mask >>= 1)
{
accum[m] += __shfl_xor_sync(0xFFFFFFFF, accum[m], mask, 32);
}
if (k_col_id == 0)
{
accum_fragment[0] = accum[m];
if (output_op.is_source_needed())
{
output_fragment = output_op(accum_fragment, source_fragment[m]);
}
else
{
output_fragment = output_op(accum_fragment);
}
*(ptr_D + m * mThreadCount) = output_fragment[0];
}
}
ptr_A += mTileA * params.problem_size.column();
ptr_C += mTileA;
ptr_D += mTileA;
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -129,21 +129,18 @@ public:
};
}
static
bool
static bool
can_implement(Arguments const& args) {
return args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
}
static
int
static int
get_workspace_size(Arguments const& args) {
return 0;
}
static constexpr
dim3
static dim3
get_grid_shape(Params const& params) {
int batch_count = 1;
if constexpr (rank(ProblemShape{}) == 4) {
@@ -157,8 +154,7 @@ public:
);
}
static constexpr
dim3
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}

View File

@@ -172,20 +172,20 @@ public:
auto N = get<1>(args.problem_shape);
auto K = get<2>(args.problem_shape);
// Contiguous dimension for the TMA tensor should be 128b aligned
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
return implementable;
}
constexpr bool is_beta_supported =
constexpr bool is_beta_supported =
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
if (!implementable) {
@@ -196,15 +196,13 @@ public:
return implementable;
}
static
int
static int
get_workspace_size(Arguments const& args) {
return 0;
}
// Computes the kernel launch grid shape based on runtime parameters
static constexpr
dim3
static dim3
get_grid_shape(Params const& params) {
auto cluster_shape = ClusterShape{};
auto tile_shape = TileShape{};
@@ -213,8 +211,7 @@ public:
problem_shape_MNKL, tile_shape, cluster_shape);
}
static constexpr
dim3
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
@@ -243,7 +240,7 @@ public:
int warp_idx = canonical_warp_idx();
int lane_predicate = cute::elect_one_sync();
// Issue Tma Descriptor Prefetch from a single thread
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
}

View File

@@ -179,20 +179,20 @@ public:
auto N = get<1>(args.problem_shape);
auto K = get<2>(args.problem_shape);
// Contiguous dimension for the TMA tensor should be 128b aligned
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
return implementable;
}
constexpr bool is_beta_supported =
constexpr bool is_beta_supported = not cute::is_void_v<ElementC> &&
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
if (!implementable) {
@@ -210,8 +210,7 @@ public:
}
// Computes the kernel launch grid shape based on runtime parameters
static constexpr
dim3
static dim3
get_grid_shape(Params const& params) {
auto cluster_shape = ClusterShape{};
auto tile_shape = TileShape{};
@@ -220,8 +219,7 @@ public:
problem_shape_MNKL, tile_shape, cluster_shape);
}
static constexpr
dim3
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
@@ -300,7 +298,7 @@ public:
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();

View File

@@ -202,20 +202,20 @@ public:
auto N = get<1>(args.problem_shape);
auto K = get<2>(args.problem_shape);
// Contiguous dimension for the TMA tensor should be 128b aligned
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
return implementable;
}
constexpr bool is_beta_supported =
constexpr bool is_beta_supported =
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
if (!implementable) {
@@ -233,15 +233,13 @@ public:
}
// Computes the kernel launch grid shape based on runtime parameters
static constexpr
dim3
static dim3
get_grid_shape(Params const& params) {
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info);
}
static constexpr
dim3
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
@@ -333,7 +331,7 @@ public:
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();

View File

@@ -110,7 +110,7 @@ public:
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<
StagesPerMathWarpGroup, NumMmaWarpGroups>;
@@ -210,20 +210,20 @@ public:
auto N = get<1>(args.problem_shape);
auto K = get<2>(args.problem_shape);
// Contiguous dimension for the TMA tensor should be 128b aligned
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
return implementable;
}
constexpr bool is_beta_supported =
constexpr bool is_beta_supported =
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
if (!implementable) {
@@ -241,15 +241,13 @@ public:
}
// Computes the kernel launch grid shape based on runtime parameters
static constexpr
dim3
static dim3
get_grid_shape(Params const& params) {
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info);
}
static constexpr
dim3
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
@@ -341,7 +339,7 @@ public:
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
@@ -389,9 +387,9 @@ public:
detail::PersistentTileSchedulerSm90 scheduler;
if (warp_group_role == WarpGroupRole::Consumer1) {
// Advance 2nd Math WG to the next work tile for the startup
// Advance 2nd Math WG to the next work tile for the startup
scheduler.advance_to_next_work();
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
mainloop_pipe_consumer_state.advance(k_tile_count);
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
@@ -486,7 +484,7 @@ public:
params.mainloop
);
// Cue for next Math WG's MMA to start
// Cue for next Math WG's MMA to start
math_wg_order_barrier.arrive();
// Make sure the math instructions are done and free buffers before entering the epilogue
@@ -522,7 +520,7 @@ public:
// Wait for all TMA stores to complete
epi_store_pipeline.producer_tail(epi_store_pipe_producer_state);
// Cue for next Math WG's Epilogue to start
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier.arrive();
// Get next work tile

View File

@@ -108,7 +108,7 @@ public:
return {work_idx_m, work_idx_n, static_cast<int32_t>(work_idx_l), current_work_linear_idx_ < scheduler_params.blocks_per_problem_};
}
CUTLASS_DEVICE
CUTLASS_DEVICE
void
advance_to_next_work(uint32_t advance_count = 1) {
current_work_linear_idx_ += grid_blocks_total_ * advance_count;
@@ -117,7 +117,7 @@ public:
// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE constexpr static
CUTLASS_HOST_DEVICE static
dim3
get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) {
// Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles
@@ -135,7 +135,7 @@ public:
// Given the inputs, computes the physical grid we should launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE constexpr static
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(ProblemShapeMNKL problem_shape_mnk, BlockShape blk_shape, ClusterShape cluster_shape, KernelHardwareInfo hw_info) {
int const sm_count = hw_info.sm_count;

View File

@@ -630,6 +630,12 @@ public:
accum = plus_accum(accum, tmp_accum);
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};

View File

@@ -690,6 +690,11 @@ public:
__syncthreads();
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -846,12 +846,10 @@ public:
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -660,12 +660,10 @@ public:
accum = plus_accum(accum, pipe_state.tmp_accum_);
}
// Optionally commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}


@@ -628,6 +628,12 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -739,6 +739,11 @@ public:
__syncthreads();
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -650,6 +650,12 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -528,12 +528,10 @@ public:
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};


@@ -524,10 +524,10 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
if (kAdvanceRank) {
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
pointer_ += Shape::kContiguous * tile_offset.contiguous();
pointer_ += Shape::kContiguous * tile_offset.contiguous() * sizeof_bits<Element>::value / 8;
} else {
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
pointer_ += Shape::kStrided * tile_offset.strided();
pointer_ += Shape::kStrided * tile_offset.strided() * sizeof_bits<Element>::value / 8;
}
} else {
coord_offset_.strided() = the_predicates.thread_offset_.strided() + Shape::kStrided * (tile_offset.strided() - kAdvanceRank);


@@ -156,11 +156,11 @@ static bool
can_implement(Arguments const& args);
// Returns a dim3 representing the threadblock shape.
static constexpr dim3
static dim3
get_block_shape();
// Returns a dim3 representing the grid shape in terms of threadblocks.
static constexpr dim3
static dim3
get_grid_shape(Params const& params);
```


@@ -353,8 +353,9 @@ Hopper architecture and beyond so as to indicate new features of the kernel with
To best illustrate this naming convention, we will walk through the meaning of each of the components
in a GEMM kernel used by the profiler:
```
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f32_128x128x64_2x1x1_0_ntn_align8
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f32_128x128x64_2x1x1_0_ntn_align8
```
The components within this name are as follows:
@@ -366,8 +367,7 @@ The components within this name are as follows:
* `s`: indicates that the Tensor Core instruction being used accumulates in single precision
(as opposed to `h`, which indicates half precision)
* `64x128x16gemm`: indicates that the shape of the Tensor Core instruction being used (MxNxK) is 64x128x16
* `f16_f16_f32_f16`: indicates that the data types for operands A, B, and C are each `f16`
(half precision) and that accumulation is performed using `f32` (single precision)
* `f16_f16_f32_f16_f16`: indicates the data types for operands A, B, the accumulator, C, and D (in that order).
* `128x128x64`: indicates that the thread block shape used in the GEMM (MxNxK) is 128x128x64
* `2x1x1`: indicates that the cluster shape being used is 2x1x1
* `0`: indicates that the kernel uses the CollectiveBuilder's automatic stage calculation to determine the
@@ -382,12 +382,24 @@ Note that in some special cases where the input A/B types do not match that of t
instruction, the MMA-facing input type is added to the instruction string as well.
```
cutlass3x_sm90_tensorop_s64x128x8tf32gemm_f32_f32_f32_f32_128x128x32_2x1x1_0_tnn_align4
cutlass3x_sm90_tensorop_s64x128x8tf32gemm_f32_f32_f32_f32_f32_128x128x32_2x1x1_0_tnn_align4
```
* `s64x128x8tf32gemm`: indicates that the MMA consumes inputs in `tf32` format, and therefore
the kernel performs rounding of the `f32` values in global memory while loading them into shared memory.
For custom mainloop or epilogue schedules, details of the opted-in schedule are appended to the end of the
kernel name. For example,
```
cutlass3x_sm90_tensorop_h64x128x16gemm_f16_f16_f16_void_f16_128x128x64_1x1x1_0_nnn_align8_warpspecialized_cooperative_epi_tma
```
* `warpspecialized_cooperative`: the kernel employs a persistent warp-specialized cooperative mainloop and kernel schedule.
* `epi_tma`: the kernel epilogue employs TMA-based vectorization.
* `f16_f16_f16_void_f16`: here the C type is set to `void`, indicating that residual-matrix support
is disabled.
# Convolution
The CUTLASS Profiler is capable of executing 2-D and 3-D convolution problems for forwards and backwards


@@ -41,7 +41,6 @@ add_custom_target(
cutlass_test_unit_gemm_device_tensorop_planar_complex
cutlass_test_unit_gemm_device_sparse_tensorop_sm80
cutlass_test_unit_gemv_device
cutlass_test_unit_gemv_device_strided_batched
cutlass_test_unit_gemm_device_tensorop_sm90
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
)
@@ -61,7 +60,6 @@ add_custom_target(
test_unit_gemm_device_tensorop_planar_complex
test_unit_gemm_device_sparse_tensorop_sm80
test_unit_gemv_device
test_unit_gemv_device_strided_batched
test_unit_gemm_device_tensorop_sm90
)
@@ -500,15 +498,6 @@ cutlass_test_unit_add_executable(
gemv.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemv_device_strided_batched
BATCH_SOURCES ON
BATCH_SIZE 4
gemv_strided_batched.cu
)
if (NOT CUDA_COMPILER MATCHES "[Cc]lang")
add_dependencies(


@@ -98,7 +98,7 @@ public:
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
uint64_t seed_ = 2023
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
@@ -156,22 +156,29 @@ public:
/// Initializes data structures
void initialize(
cutlass::MatrixCoord problem_size
cutlass::MatrixCoord problem_size,
int32_t batch_count
) {
//
// Allocate the GEMM workspace
// Allocate the GEMV workspace
//
tensor_A.resize(problem_size);
tensor_B.resize({problem_size.column(), 1});
tensor_C.resize({problem_size.row(), 1});
tensor_D.resize({problem_size.row(), 1});
reference_D.resize({problem_size.row(), 1}, false);
if(std::is_same<LayoutA, cutlass::layout::ColumnMajor>::value) {
tensor_A.resize({problem_size.row(), batch_count * problem_size.column()});
}
else {
tensor_A.resize({batch_count * problem_size.row(), problem_size.column()});
}
tensor_B.resize({batch_count * problem_size.column(), 1});
tensor_C.resize({batch_count * problem_size.row(), 1});
tensor_D.resize({batch_count * problem_size.row(), 1});
reference_D.resize({batch_count * problem_size.row(), 1}, false);
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 1));
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2));
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 3));
// It is possible to randomly initialize to all zeros, so override this with non-zeros
// in the upper left corner of each operand.
@@ -225,9 +232,14 @@ public:
return passed;
}
/// Verifies the result is a GEMM
/// Verifies the result
bool verify(
cutlass::MatrixCoord problem_size,
int32_t batch_count,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
ElementCompute alpha,
ElementCompute beta) {
@@ -242,7 +254,7 @@ public:
ElementCompute, ElementAccumulator
>(
{problem_size.row(), 1, problem_size.column()},
alpha,
tensor_A.host_ref(),
Gemv::kTransformA,
tensor_B.host_ref(),
@@ -250,7 +262,12 @@ public:
beta,
tensor_C.host_ref(),
reference_D.host_ref(),
ElementAccumulator(0)
ElementAccumulator(0),
batch_count,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D
);
return compare_reference(problem_size, alpha, beta);
@@ -259,39 +276,50 @@ public:
/// Runs one problem size
bool run(
cutlass::MatrixCoord problem_size,
int32_t batch_count,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
ElementCompute alpha,
ElementCompute beta) {
this->initialize(problem_size);
this->initialize(problem_size, batch_count);
//
// Initialize the GEMM operator
// Initialize the GEMV operator
//
typename Gemv::Arguments arguments{
problem_size,
batch_count,
{alpha, beta},
tensor_A.device_ref(),
tensor_B.device_data(),
tensor_C.device_data(),
tensor_D.device_data(),
tensor_B.layout().stride(0),
tensor_C.layout().stride(0),
tensor_D.layout().stride(0)
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D
};
Gemv gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
size_t workspace_size = Gemv::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
status = gemm_op.initialize(arguments, workspace.get());
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
//
// Run the GEMM
// Run the GEMV
//
status = gemm_op();
@@ -302,8 +330,15 @@ public:
// Verify
//
bool passed = this->verify(problem_size, alpha, beta);
bool passed = this->verify(
problem_size,
batch_count,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D,
alpha,
beta);
return passed;
}
};
@@ -315,12 +350,16 @@ bool TestAllGemv() {
using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute;
int Batch[] = {
1, 520, 1314
};
int M[] = {
8, 48, 192, 520
1, 5, 16
};
int K[] = {
8, 192, 528
8, 128, 256
};
double Alpha[] = {
@@ -331,15 +370,25 @@ bool TestAllGemv() {
0, 1, 1.25
};
for (int m : M) {
for (int k : K) {
for (double alpha : Alpha) {
for (double beta : Beta) {
for (int b : Batch) {
for (int m : M) {
for (int k : K) {
for (double alpha : Alpha) {
for (double beta : Beta) {
TestbedGemv<Gemv> testbed;
if (!testbed.run({m, k}, ElementCompute(alpha), ElementCompute(beta))) {
return false;
if (!testbed.run(
{m, k},
b,
m * k,
k,
m,
m,
ElementCompute(alpha),
ElementCompute(beta))) {
return false;
}
}
}
}
@@ -354,66 +403,100 @@ bool TestAllGemv() {
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, Simple) {
using ElementOutput = float;
using LayoutA = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using Gemv = cutlass::gemm::device::Gemv<
cutlass::gemm::kernel::Gemv<
ElementOutput, // Element A
LayoutA, // Layout A
ElementOutput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element Accumulator
EpilogueOp // Output operator
>
>;
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Gemv_f16n_f16_f32_simt_f32, Simple) {
TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, RowMajorA) {
using ElementInput = cutlass::half_t;
using ElementOutput = float;
using LayoutA = cutlass::layout::ColumnMajor;
using ElementOutput = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = float;
int const kElementsPerAccess = 8;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator,
ElementAccumulator>;
using Gemv = cutlass::gemm::device::Gemv<
cutlass::gemm::kernel::Gemv<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element Accumulator
EpilogueOp // Output operator
>
>;
cutlass::gemm::kernel::Gemv<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element accumulator
EpilogueOp, // Output operator
kElementsPerAccess // Element access granularity
>
>;
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, Simple) {
TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, RowMajorA) {
using ElementInput = float;
using ElementOutput = float;
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = float;
int const kElementsPerAccess = 4;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using Gemv = cutlass::gemm::device::Gemv<
cutlass::gemm::kernel::Gemv<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element accumulator
EpilogueOp, // Output operator
kElementsPerAccess // Element access granularity
>
>;
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Gemv_f64n_f64_f64_simt_f64, RowMajorA) {
using ElementInput = double;
using ElementOutput = double;
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = double;
int const kElementsPerAccess = 2;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using Gemv = cutlass::gemm::device::Gemv<
cutlass::gemm::kernel::Gemv<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element accumulator
EpilogueOp, // Output operator
kElementsPerAccess // Element access granularity
>
>;
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, ColumnMajorA) {
using ElementInput = cutlass::half_t;
using ElementOutput = cutlass::half_t;
@@ -442,3 +525,63 @@ TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, Simple) {
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, ColumnMajorA) {
using ElementInput = float;
using ElementOutput = float;
using LayoutA = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using Gemv = cutlass::gemm::device::Gemv<
cutlass::gemm::kernel::Gemv<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element Accumulator
EpilogueOp // Output operator
>
>;
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Gemv_f64n_f64_f64_simt_f64, ColumnMajorA) {
using ElementInput = double;
using ElementOutput = double;
using LayoutA = cutlass::layout::ColumnMajor;
using ElementAccumulator = double;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using Gemv = cutlass::gemm::device::Gemv<
cutlass::gemm::kernel::Gemv<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element Accumulator
EpilogueOp // Output operator
>
>;
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -1,490 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide strided batched GEMV interface
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/gemv_strided_batched.h"
#include "cutlass/gemm/device/gemv_strided_batched.h"
#include "../../common/cutlass_unit_test.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "testbed_utils.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace test {
namespace gemm {
template <typename GemvStridedBatched>
class TestbedStridedBatchedGemv
{
public:
using ElementA = typename GemvStridedBatched::ElementA;
using LayoutA = typename GemvStridedBatched::LayoutA;
using ElementB = typename GemvStridedBatched::ElementB;
using ElementC = typename GemvStridedBatched::ElementC;
using ElementAccumulator = typename GemvStridedBatched::ElementAccumulator;
using ElementCompute = typename GemvStridedBatched::EpilogueOutputOp::ElementCompute;
using LayoutV = cutlass::layout::RowMajor;
private:
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
uint64_t seed;
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
cutlass::HostTensor<ElementB, LayoutV> tensor_B;
cutlass::HostTensor<ElementC, LayoutV> tensor_C;
cutlass::HostTensor<ElementC, LayoutV> tensor_D;
cutlass::HostTensor<ElementC, LayoutV> reference_D;
public:
//
// Methods
//
TestbedStridedBatchedGemv(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2023):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {}
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<typename GemvStridedBatched::ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
// TODO: Implement the rest
EXPECT_TRUE(false) << "Not implemented";
return false;
}
return true;
}
/// Initializes data structures
void initialize(
cutlass::MatrixCoord problem_size,
int32_t batch_count
) {
//
// Allocate the GEMV workspace
//
tensor_A.resize({batch_count * problem_size.row(), problem_size.column()});
tensor_B.resize({batch_count * problem_size.column(), 1});
tensor_C.resize({batch_count * problem_size.row(), 1});
tensor_D.resize({batch_count * problem_size.row(), 1});
reference_D.resize({batch_count * problem_size.row(), 1}, false);
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 1));
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2));
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 3));
// It is possible to randomly initialize to all zeros, so override this with non-zeros
// in the upper left corner of each operand.
tensor_A.host_view().at({0, 0}) = typename GemvStridedBatched::ElementA(1);
tensor_B.host_view().at({0, 0}) = typename GemvStridedBatched::ElementB(1);
tensor_C.host_view().at({0, 0}) = typename GemvStridedBatched::ElementC(1);
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
}
/// Compares computed reference with device reference and outputs to a file if incorrect
bool compare_reference(
cutlass::MatrixCoord problem_size,
ElementCompute alpha,
ElementCompute beta) {
tensor_D.sync_host();
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
EXPECT_TRUE(passed) << " mismatched reference";
if (!passed) {
std::ofstream file("testbed_universal_errors.txt");
file
<< "problem: " << problem_size
<< ", alpha: " << alpha << ", beta: " << beta << "\n\n";
file
<< "A =\n" << tensor_A.host_view()
<< "\nB =\n" << tensor_B.host_view()
<< "\nC =\n" << tensor_C.host_view()
<< "\n\nReference =\n" << reference_D.host_view()
<< "\nComputed =\n" << tensor_D.host_view();
}
return passed;
}
/// Verifies the result
bool verify(
cutlass::MatrixCoord problem_size,
int32_t batch_count,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
ElementCompute alpha,
ElementCompute beta) {
//
// Verify
//
cutlass::reference::host::GemmComplex<
typename GemvStridedBatched::ElementA, typename GemvStridedBatched::LayoutA,
typename GemvStridedBatched::ElementB, LayoutV,
typename GemvStridedBatched::ElementC, LayoutV,
ElementCompute, ElementAccumulator
>(
{problem_size.row(), 1, problem_size.column()},
alpha,
tensor_A.host_ref(),
GemvStridedBatched::kTransformA,
tensor_B.host_ref(),
GemvStridedBatched::kTransformB,
beta,
tensor_C.host_ref(),
reference_D.host_ref(),
ElementAccumulator(0),
batch_count,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D
);
return compare_reference(problem_size, alpha, beta);
}
/// Runs one problem size
bool run(
cutlass::MatrixCoord problem_size,
int32_t batch_count,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
ElementCompute alpha,
ElementCompute beta) {
this->initialize(problem_size, batch_count);
//
// Initialize the GEMV operator
//
typename GemvStridedBatched::Arguments arguments{
problem_size,
batch_count,
{alpha, beta},
tensor_A.device_ref(),
tensor_B.device_data(),
tensor_C.device_data(),
tensor_D.device_data(),
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D
};
GemvStridedBatched gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
size_t workspace_size = GemvStridedBatched::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = gemm_op.initialize(arguments, workspace.get());
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
//
// Run the GEMV
//
status = gemm_op();
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
//
// Verify
//
bool passed = this->verify(
problem_size,
batch_count,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D,
alpha,
beta);
return passed;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemvStridedBatched>
bool TestAllGemv() {
using ElementCompute = typename GemvStridedBatched::EpilogueOutputOp::ElementCompute;
int Batch[] = {
1, 520, 1314
};
int M[] = {
1, 5, 16
};
int K[] = {
8, 128, 256
};
double Alpha[] = {
1, 1.25
};
double Beta[] = {
0, 1, 1.25
};
for (int b : Batch) {
for (int m : M) {
for (int k : K) {
for (double alpha : Alpha) {
for (double beta : Beta) {
TestbedStridedBatchedGemv<GemvStridedBatched> testbed;
if (!testbed.run(
{m, k},
b,
m * k,
k,
m,
m,
ElementCompute(alpha),
ElementCompute(beta))) {
return false;
}
}
}
}
}
}
return true;
}
} // namespace gemm
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_StridedBatchedGemv_f16n_f16_f16_simt_f32, Simple) {
using ElementInput = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = float;
int const kElementsPerAccess = 8;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using GemvStridedBatched = cutlass::gemm::device::GemvStridedBatched<
cutlass::gemm::kernel::GemvStridedBatched<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element accumulator
kElementsPerAccess, // Element access granularity
EpilogueOp // Output operator
>>;
EXPECT_TRUE(test::gemm::TestAllGemv<GemvStridedBatched>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_StridedBatchedGemv_f32n_f32_f32_simt_f32, Simple) {
using ElementInput = float;
using ElementOutput = float;
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = float;
int const kElementsPerAccess = 4;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using GemvStridedBatched = cutlass::gemm::device::GemvStridedBatched<
cutlass::gemm::kernel::GemvStridedBatched<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element accumulator
kElementsPerAccess, // Element access granularity
EpilogueOp // Output operator
>>;
EXPECT_TRUE(test::gemm::TestAllGemv<GemvStridedBatched>());}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_StridedBatchedGemv_f64n_f64_f64_simt_f64, Simple) {
using ElementInput = double;
using ElementOutput = double;
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = double;
int const kElementsPerAccess = 2;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,
ElementAccumulator,
ElementAccumulator>;
using GemvStridedBatched = cutlass::gemm::device::GemvStridedBatched<
cutlass::gemm::kernel::GemvStridedBatched<
ElementInput, // Element A
LayoutA, // Layout A
ElementInput, // Element B
ElementOutput, // Element C
ElementAccumulator, // Element accumulator
kElementsPerAccess, // Element access granularity
EpilogueOp // Output operator
>>;
EXPECT_TRUE(test::gemm::TestAllGemv<GemvStridedBatched>());}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -773,7 +773,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
@@ -810,7 +810,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
@@ -847,4 +847,78 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
float, LayoutC, 4,
float, LayoutC, 4,
cutlass::epilogue::TmaWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
float, LayoutC, 4,
float, LayoutC, 4,
cutlass::epilogue::TmaWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
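The tests above rely on `StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>`: the mainloop stage count is chosen to fill whatever shared memory remains after the epilogue's storage is carved out. A rough model of that accounting follows; all byte counts are illustrative assumptions, not CUTLASS's actual bookkeeping:

```python
# Rough model of StageCountAutoCarveout: mainloop stages fill the
# shared memory left over after the epilogue's SharedStorage is
# reserved. Capacity and epilogue sizes below are hypothetical.
def carveout_stage_count(smem_capacity, epilogue_smem,
                         tile_m, tile_n, tile_k,
                         bytes_a=2, bytes_b=2):
    # One mainloop stage buffers an A tile (M x K) and a B tile (N x K).
    stage_bytes = tile_m * tile_k * bytes_a + tile_n * tile_k * bytes_b
    return (smem_capacity - epilogue_smem) // stage_bytes

# Hypothetical 228 KiB budget, 16 KiB epilogue, 128x128x64 fp16 tiles.
print(carveout_stage_count(228 * 1024, 16 * 1024, 128, 128, 64))  # 6
```

A larger epilogue footprint (e.g. a TMA epilogue that stages the output tile through shared memory) directly reduces the number of mainloop stages, which is the trade-off these `_epilogue` test variants exercise.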

View File

@@ -363,4 +363,48 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU_VoidC) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_256,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
static constexpr bool StoreT = true;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
EXPECT_TRUE(passed);
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -1104,13 +1104,13 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 128
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) {
TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using ElementC = ElementA;
using ElementC = cutlass::half_t;
using LayoutC = cutlass::layout::ColumnMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
@@ -1121,20 +1121,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),
ElementC, LayoutC, 16 / sizeof(ElementC),
cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, 8,
ElementB, LayoutB, 8,
ElementA, LayoutA, 16 / sizeof(ElementA),
ElementB, LayoutB, 16 / sizeof(ElementB),
ElementAccumulator,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
@@ -1147,13 +1147,13 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) {
TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using ElementC = ElementA;
using ElementC = cutlass::half_t;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
@@ -1164,20 +1164,106 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent, 128x128x64_2
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
cutlass::half_t, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),
ElementC, LayoutC, 16 / sizeof(ElementC),
cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, 8,
ElementB, LayoutB, 8,
ElementA, LayoutA, 16 / sizeof(ElementA),
ElementB, LayoutB, 16 / sizeof(ElementB),
ElementAccumulator,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using ElementC = float;
using LayoutC = cutlass::layout::ColumnMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),
ElementC, LayoutC, 16 / sizeof(ElementC),
cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, 16 / sizeof(ElementA),
ElementB, LayoutB, 16 / sizeof(ElementB),
ElementAccumulator,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using ElementC = float;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, 16 / sizeof(ElementC),
ElementC, LayoutC, 16 / sizeof(ElementC),
cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, 16 / sizeof(ElementA),
ElementB, LayoutB, 16 / sizeof(ElementB),
ElementAccumulator,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<

View File

@@ -362,4 +362,48 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU_VoidC) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using TileShape_MNK = Shape<_128,_128,_64>;
using ClusterShape_MNK = Shape<_2,_2,_1>;
static constexpr bool StoreT = true;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
void, LayoutC, 8,
cutlass::half_t, LayoutC, 8,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, 8,
cutlass::half_t, LayoutB, 8,
float,
TileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
EXPECT_TRUE(passed);
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -269,6 +269,150 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x2x1) {
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32_pingpong_epilogue, 64x128x128) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
int32_t, int32_t,
int8_t, LayoutC, 16,
int8_t, LayoutC, 16,
cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
int8_t, LayoutA, 16,
int8_t, LayoutB, 16,
int32_t,
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_s8t_s8n_s8t_tensor_op_gmma_s32_pingpong_epilogue, 64x128x128) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
int32_t, int32_t,
int8_t, LayoutC, 16,
int8_t, LayoutC, 16,
cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
int8_t, LayoutA, 16,
int8_t, LayoutB, 16,
int32_t,
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32_cooperative_epilogue, 128x128x128) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
int32_t, int32_t,
int8_t, LayoutC, 16,
int8_t, LayoutC, 16,
cutlass::epilogue::TmaWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
int8_t, LayoutA, 16,
int8_t, LayoutB, 16,
int32_t,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
TEST(SM90_Device_Gemm_s8t_s8n_s8t_tensor_op_gmma_s32_cooperative_epilogue, 128x128x128) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
int32_t, int32_t,
int8_t, LayoutC, 16,
int8_t, LayoutC, 16,
cutlass::epilogue::TmaWarpSpecializedCooperative
>::CollectiveOp;
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
int8_t, LayoutA, 16,
int8_t, LayoutB, 16,
int32_t,
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}
///////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -4079,32 +4079,36 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
max_cc = 90
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
tile_descriptions_small = [
# Not compatible with TmaWarpSpecializedCooperative
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_large = [
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - Not compatible with TmaWarpSpecializedCooperative
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - Not compatible with TmaWarpSpecializedCooperative
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - Not compatible with TmaWarpSpecializedCooperative
]
tile_descriptions = tile_descriptions_medium + tile_descriptions_large
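The tiers above scale the WGMMA instruction shape into tile shapes: small tiles use the instruction M, medium doubles it, and large quadruples M (or doubles both M and N). A sketch of that derivation, using an illustrative `[M, N, K]` instruction shape rather than one from a specific math instruction:

```python
# Tile shapes as multiples of the WGMMA instruction shape, mirroring
# the small/medium/large tiers. instruction_shape is illustrative.
instruction_shape = [64, 128, 16]

def tile(m_mult, n_mult, k_mult):
    m, n, k = instruction_shape
    return [m * m_mult, n * n_mult, k * k_mult]

tiles_small  = [tile(1, 1, 4)]                 # not compatible with cooperative schedules
tiles_medium = [tile(2, 1, 4)]
tiles_large  = [tile(4, 1, 4), tile(2, 2, 4)]
print(tiles_medium + tiles_large)  # [[128, 128, 64], [256, 128, 64], [128, 256, 64]]
```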
data_type = {
"a_type" : math_inst.element_a,
@@ -4139,11 +4143,17 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules)
# persistent kernels with TMA epilogues
if data_type["c_type"] in [DataType.f16, DataType.bf16] and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
# Emit instance without C allocation+load
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
# not enough smem for 256x128 f32 out with C allocation
if data_type["d_type"] == DataType.f32:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
else:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
# Emit instance without C allocation + load
data_type["c_type"] = DataType.void
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
@@ -4170,7 +4180,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules)
# persistent kernels with TMA epilogues
if data_type_mixed["c_type"] in [DataType.f16, DataType.bf16] and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
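The branch above restricts TMA-epilogue instances with f32 output to the medium tile tier, per the comment that a 256x128 f32 output tile plus a C-operand allocation exceeds shared memory. The selection logic reduces to:

```python
# Sketch of the tile-tier selection for persistent TMA-epilogue
# kernels: wide-accumulator (f32) output falls back to medium tiles.
def pick_tiles(d_type, medium, large):
    if d_type == "f32":
        return medium          # 256x128 f32 out with C allocation exceeds smem
    return medium + large

medium = [[128, 128, 64]]
large  = [[256, 128, 64], [128, 256, 64]]
print(pick_tiles("f32", medium, large))  # [[128, 128, 64]]
print(pick_tiles("f16", medium, large))
```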
@@ -4331,6 +4341,28 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
for data_type in data_types:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type)
# persistent kernels with TMA epilogues
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
# Emit instance without C allocation+load
data_types += [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : math_inst.element_accumulator,
"acc_type" : math_inst.element_accumulator,
"epi_type" : math_inst.element_accumulator
}
]
for data_type in data_types:
# Set alignment of D based on the destination format.
for layout in layouts:
layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
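The loop above recomputes D's alignment per layout as `128 // DataTypeSize[d_type]`: 128 bits divided by the destination type's bit width. A sketch with illustrative `DataTypeSize` entries and a simplified layout tuple shape:

```python
# D's alignment (in elements) for a 128-bit access, updated in place
# on each layout triple, mirroring layout[2][1] = 128 // DataTypeSize.
DATA_TYPE_SIZE = {"s8": 8, "s32": 32, "f16": 16}

def set_d_alignment(layouts, d_type):
    for layout in layouts:
        layout[2][1] = 128 // DATA_TYPE_SIZE[d_type]  # elements per 128-bit access
    return layouts

layouts = [[["RowMajor", 16], ["ColumnMajor", 16], ["ColumnMajor", 4]]]
print(set_d_alignment(layouts, "s8"))  # D alignment becomes 16
```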
#
def GenerateSM90_TensorOp_1684(manifest, cuda_version):

View File

@@ -345,358 +345,6 @@ void initialize_gemm_reference_operations(Manifest &manifest) {
complex<double>
>(manifest);
//
// FP8 GEMMs
//
//////////////////////////////////
/// ElementC: half_t
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
/// ElementC: bfloat16_t
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
bfloat16_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
bfloat16_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
bfloat16_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
bfloat16_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
bfloat16_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
/// ElementC: float
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e4m3_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e5m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e4m3_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
//////////////////////////////////
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
make_gemm_real_canonical_layouts<
float_e5m2_t, // ElementA
float_e5m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
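The hand-written block above enumerates every FP8 reference-GEMM instantiation: A and B range over the two FP8 types, C over {half_t, bfloat16_t, float}, and D is either C's type or one of the FP8 types. The same set can be generated programmatically, which also confirms the count:

```python
# Generate the FP8 reference-GEMM (A, B, C, D) combinations that the
# make_gemm_real_canonical_layouts block spells out by hand.
from itertools import product

fp8 = ["float_e4m3_t", "float_e5m2_t"]
c_types = ["half_t", "bfloat16_t", "float"]

combos = [
    (a, b, c, d)
    for a, b, c in product(fp8, fp8, c_types)
    for d in (c, "float_e4m3_t", "float_e5m2_t")  # D matches C, or is FP8
]
print(len(combos))  # 36 instantiations, matching the hand-written list
```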

View File

@@ -103,14 +103,20 @@ bool get_cublas_transpose_operation(
/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration
bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type) {
switch (element_type) {
switch (element_type) {
case library::NumericTypeID::kFE4M3:
#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
data_type = CUDA_R_8F_E4M3;
return true;
#endif
break;
case library::NumericTypeID::kFE5M2:
#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
data_type = CUDA_R_8F_E5M2;
return true;
#endif
break;
case library::NumericTypeID::kF16:
data_type = CUDA_R_16F;
@@ -139,7 +145,7 @@ bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID ele
return true;
case library::NumericTypeID::kS16:
break;
break;
case library::NumericTypeID::kS32:
data_type = CUDA_R_32I;
@@ -260,6 +266,13 @@ Status cublas_satisfies(library::GemmDescription const &desc) {
return Status::kErrorNotSupported;
}
// input types BF16 and TF32 are not supported in cuBLAS
if (desc.A.element == library::NumericTypeID::kBF16 ||
desc.A.element == library::NumericTypeID::kTF32) {
return Status::kErrorNotSupported;
}
return Status::kSuccess;
}
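Taken together, the two hunks above gate the cuBLAS comparison path twice: the FP8 data-type enums are only emitted when the compiler is CUDA 11.8 or newer, and BF16/TF32 inputs are rejected outright. A sketch of that combined predicate (type-ID strings mirror `library::NumericTypeID`; the version logic is an assumption drawn from the preprocessor guards above):

```python
# Capability gate for the cuBLAS reference path: BF16/TF32 inputs are
# rejected, and FP8 enums require CUDA >= 11.8.
def cublas_supports(element, cuda_version):
    if element in ("kBF16", "kTF32"):
        return False                    # not supported in this cuBLAS path
    if element in ("kFE4M3", "kFE5M2"):
        return cuda_version >= (11, 8)  # CUDA_R_8F_* enums need CUDA 11.8+
    return True

print(cublas_supports("kFE4M3", (11, 7)))  # False
print(cublas_supports("kFE4M3", (12, 0)))  # True
print(cublas_supports("kTF32", (12, 1)))   # False
```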