mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 14:59:01 +00:00
More updates for 3.1 (#958)
* Updates for 3.1 * Minor change * doc link fix * Minor updates
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
|
||||
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper.
|
||||
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU and GELU, using the new efficient epilogues.
|
||||
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
@@ -12,6 +12,11 @@
|
||||
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
|
||||
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
|
||||
* [Permute + GEMM fusion](examples/39_gemm_permute) can now fuse Permute with the following GEMM. Previously, only fusing GEMM with Permute in the epilogue was supported.
|
||||
* The GitHub branch is renamed from `master` to `main` in this release.
|
||||
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
@@ -46,7 +46,7 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im
|
||||
CUTLASS 3.1 is an update to CUTLASS adding:
|
||||
|
||||
- New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
|
||||
- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper.
|
||||
- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
- Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU and GELU, using the new efficient epilogues.
|
||||
- New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
- New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that improves performance on Hopper.
|
||||
@@ -54,6 +54,12 @@ CUTLASS 3.1 is an update to CUTLASS adding:
|
||||
- New Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
- Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
- Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
- [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
- [Streamk GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
- [Batched B2B GEMM](examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
|
||||
- [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
|
||||
- [Permute + GEMM fusion](examples/39_gemm_permute) can now fuse Permute with the following GEMM. Previously, only fusing GEMM with Permute in the epilogue was supported.
|
||||
|
||||
- *Announcement*:
|
||||
- The GitHub branch is renamed from `master` to `main` in this release.
|
||||
- A slight modification has been made to the ordering of arguments passed in to epilogues in 3.x kernels.
|
||||
|
||||
@@ -641,6 +641,11 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
@@ -871,7 +876,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -664,6 +664,11 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
@@ -855,7 +860,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -759,13 +759,10 @@ public:
|
||||
accum1 = plus_accum(accum1, tmp_accum1);
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -461,11 +461,6 @@ Result run(std::string description, Options &options)
|
||||
std::cout << " GFLOPs: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
// TODO: uncomment when results match
|
||||
//if (!result.passed) {
|
||||
// exit(-1);
|
||||
//}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -499,7 +499,7 @@ flatten(T const& t)
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Shortcut around tuple_cat for common insert/remove/repeat cases
|
||||
// Shortcut around cute::tuple_cat for common insert/remove/repeat cases
|
||||
template <class T, class X, int... I, int... J, int... K>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
|
||||
@@ -623,7 +623,7 @@ partition_shape_C(TiledMMA<Args...> const& mma, Shape_MN const& shape_MN)
|
||||
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutC_TV{});
|
||||
auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK));
|
||||
auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK));
|
||||
return tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN));
|
||||
return cute::tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN));
|
||||
}
|
||||
|
||||
template <class... Args, class Shape_MN>
|
||||
@@ -651,7 +651,7 @@ partition_shape_A(TiledMMA<Args...> const& mma, Shape_MK const& shape_MK)
|
||||
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutA_TV{});
|
||||
auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK));
|
||||
auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK));
|
||||
return tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK));
|
||||
return cute::tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK));
|
||||
}
|
||||
|
||||
template <class... Args, class Shape_NK>
|
||||
@@ -666,7 +666,7 @@ partition_shape_B(TiledMMA<Args...> const& mma, Shape_NK const& shape_NK)
|
||||
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutB_TV{});
|
||||
auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK));
|
||||
auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK));
|
||||
return tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK));
|
||||
return cute::tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK));
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -46,8 +46,14 @@ namespace cute
|
||||
|
||||
using dim3 = ::dim3;
|
||||
|
||||
// MSVC doesn't define its C++ version macro to match
|
||||
// its C++ language version. This means that when
|
||||
// building with MSVC, dim3 isn't constexpr-friendly.
|
||||
template <size_t I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
CUTE_HOST_DEVICE
|
||||
#if ! defined(_MSC_VER)
|
||||
constexpr
|
||||
#endif
|
||||
uint32_t& get(dim3& a)
|
||||
{
|
||||
static_assert(I < 3, "Index out of range");
|
||||
@@ -63,7 +69,10 @@ uint32_t& get(dim3& a)
|
||||
}
|
||||
|
||||
template <size_t I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
CUTE_HOST_DEVICE
|
||||
#if ! defined(_MSC_VER)
|
||||
constexpr
|
||||
#endif
|
||||
uint32_t const& get(dim3 const& a)
|
||||
{
|
||||
static_assert(I < 3, "Index out of range");
|
||||
@@ -79,7 +88,10 @@ uint32_t const& get(dim3 const& a)
|
||||
}
|
||||
|
||||
template <size_t I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
CUTE_HOST_DEVICE
|
||||
#if ! defined(_MSC_VER)
|
||||
constexpr
|
||||
#endif
|
||||
uint32_t&& get(dim3&& a)
|
||||
{
|
||||
static_assert(I < 3, "Index out of range");
|
||||
|
||||
@@ -86,18 +86,11 @@ constexpr auto
|
||||
sm90_compute_tile_shape_or_override() {
|
||||
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
|
||||
constexpr int SmemAlloc = 4096;
|
||||
if constexpr (detail::sm90_is_cooperative_v<Schedule>) {
|
||||
constexpr int M = 128;
|
||||
constexpr int N = SmemAlloc / (M * sizeof(Element));
|
||||
|
||||
return make_shape(Int<M>{}, Int<N>{});
|
||||
return Shape<_128,_16>{};
|
||||
}
|
||||
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
|
||||
constexpr int M = 64;
|
||||
constexpr int N = SmemAlloc / (M * sizeof(Element));
|
||||
|
||||
return make_shape(Int<M>{}, Int<N>{});
|
||||
return Shape<_64,_32>{};
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<Schedule>, "Unsupported schedule.");
|
||||
@@ -167,8 +160,8 @@ template <
|
||||
class EpilogueTileType,
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
class ElementC,
|
||||
class GmemLayoutTagC,
|
||||
class ElementC_,
|
||||
class GmemLayoutTagC_,
|
||||
int AlignmentC,
|
||||
class ElementD,
|
||||
class GmemLayoutTagD,
|
||||
@@ -178,6 +171,11 @@ template <
|
||||
class DispatchPolicy
|
||||
>
|
||||
struct TmaBuilderImpl {
|
||||
|
||||
// Passing void C disables source load
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
|
||||
|
||||
using GmemStrideTypeC = gemm::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = gemm::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
@@ -188,7 +186,7 @@ struct TmaBuilderImpl {
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
EpilogueTile_MN,
|
||||
ElementC,
|
||||
ElementC_, // Need to pass void through to expose via GemmUniversal
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
@@ -246,8 +244,9 @@ struct CollectiveBuilder<
|
||||
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
|
||||
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||
|
||||
static constexpr int FragmentSize = 1;
|
||||
using ThreadOp = thread::LinearCombination<
|
||||
ElementD, 1, ElementAccumulator, ElementCompute,
|
||||
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
|
||||
ScaleType, FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
||||
@@ -267,7 +266,7 @@ template <
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
class ElementC_,
|
||||
class GmemLayoutTagC_,
|
||||
class GmemLayoutTagC,
|
||||
int AlignmentC,
|
||||
class ElementD,
|
||||
class GmemLayoutTagD,
|
||||
@@ -283,7 +282,7 @@ struct CollectiveBuilder<
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC_,
|
||||
GmemLayoutTagC_,
|
||||
GmemLayoutTagC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
GmemLayoutTagD,
|
||||
@@ -292,43 +291,26 @@ struct CollectiveBuilder<
|
||||
cute::enable_if_t<cute::is_same_v<Schedule, TmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, TmaWarpSpecializedCooperative> >> {
|
||||
public:
|
||||
// Passing void C disables source load
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
||||
ElementD, ElementC_>; // prevents cute breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
||||
GmemLayoutTagD, GmemLayoutTagC_>;
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
|
||||
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
|
||||
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||
|
||||
static constexpr int FragmentSize = 4;
|
||||
using ThreadOp = thread::LinearCombination<
|
||||
ElementD, AlignmentD, ElementAccumulator, ElementCompute,
|
||||
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
|
||||
ScaleType, FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
using GmemStrideTypeC = gemm::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = gemm::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override<
|
||||
ElementD, EpilogueTileType, Schedule>());
|
||||
|
||||
private:
|
||||
static constexpr int StagesC = 1;
|
||||
static constexpr int StagesD = 2;
|
||||
static constexpr bool DisableReuseSmemC = true;
|
||||
using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue<
|
||||
cutlass::epilogue::Sm90TmaWarpSpecialized<StagesC,StagesD,DisableReuseSmemC>,
|
||||
TileShape_MNK,
|
||||
EpilogueTile_MN,
|
||||
ElementC_, // need to pass void to expose via GemmUniversal
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
GmemStrideTypeD,
|
||||
ThreadOp,
|
||||
SM90_TMA_LOAD,
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, TileShape_MNK>()),
|
||||
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
|
||||
SM90_TMA_STORE,
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
|
||||
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>())
|
||||
>;
|
||||
using Impl = detail::TmaBuilderImpl<
|
||||
TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute,
|
||||
ElementC_, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD,
|
||||
Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecialized<StagesC,StagesD, DisableReuseSmemC>>;
|
||||
|
||||
public:
|
||||
using CollectiveOp = typename Impl::CollectiveOp;
|
||||
};
|
||||
|
||||
// Auto builder
|
||||
@@ -427,11 +409,11 @@ struct CollectiveBuilder<
|
||||
Schedule,
|
||||
cute::enable_if_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule> ||
|
||||
cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
|
||||
|
||||
public:
|
||||
static constexpr int FragmentSize = 4;
|
||||
using ThreadOp = thread::LinearCombinationGeneric<
|
||||
Schedule::ActivationFunctor,
|
||||
ElementD, AlignmentD,
|
||||
ElementD, FragmentSize,
|
||||
ElementAccumulator, ElementCompute, Schedule::Scale,
|
||||
Schedule::Round>;
|
||||
|
||||
@@ -455,7 +437,7 @@ template <
|
||||
class EpilogueTileType,
|
||||
class ElementAccumulator,
|
||||
class ElementCompute,
|
||||
class ElementC,
|
||||
class ElementC_,
|
||||
class GmemLayoutTagC,
|
||||
int AlignmentC,
|
||||
class ElementD,
|
||||
@@ -471,7 +453,7 @@ struct CollectiveBuilder<
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
ElementC_,
|
||||
GmemLayoutTagC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
@@ -480,10 +462,14 @@ struct CollectiveBuilder<
|
||||
Schedule,
|
||||
cute::enable_if_t<cute::is_base_of_v<TmaWarpSpecializedBiasElementwiseBase, Schedule> ||
|
||||
cute::is_base_of_v<TmaWarpSpecializedCooperativeBiasElementwiseBase, Schedule> >> {
|
||||
private:
|
||||
// Passing void C disables source load
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>, ElementD, ElementC_>; // prevents void ref breakages
|
||||
|
||||
public:
|
||||
static constexpr int FragmentSize = 4;
|
||||
using ThreadOp = thread::LinearCombinationBiasElementwise<
|
||||
ElementC, ElementAccumulator, ElementCompute, ElementD, typename Schedule::ElementT, AlignmentD,
|
||||
ElementC, ElementAccumulator, ElementCompute, ElementD, typename Schedule::ElementT, FragmentSize,
|
||||
typename Schedule::ActivationFunctor<ElementCompute>, typename Schedule::BiasOp<ElementCompute>,
|
||||
Schedule::StoreT, typename Schedule::ElementBias>;
|
||||
|
||||
@@ -492,7 +478,7 @@ private:
|
||||
static constexpr int StagesD = 2;
|
||||
using Impl = detail::TmaBuilderImpl<
|
||||
TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute,
|
||||
ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD,
|
||||
ElementC_, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD,
|
||||
Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise<StagesC,StagesD>>;
|
||||
|
||||
public:
|
||||
@@ -540,8 +526,9 @@ struct CollectiveBuilder<
|
||||
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
|
||||
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
||||
|
||||
static constexpr int FragmentSize = 1;
|
||||
using ThreadOp = thread::LinearCombination<
|
||||
ElementD, 1, ElementAccumulator, ElementCompute,
|
||||
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
|
||||
ScaleType, FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
||||
|
||||
@@ -75,6 +75,9 @@ public:
|
||||
using ElementD = typename ThreadEpilogueOp::ElementD;
|
||||
using StrideD = StrideD_;
|
||||
|
||||
using GmemTiledCopyC = void;
|
||||
using GmemTiledCopyD = void;
|
||||
|
||||
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
|
||||
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
||||
|
||||
|
||||
@@ -48,8 +48,8 @@ template <
|
||||
int StagesC_,
|
||||
int StagesD_,
|
||||
bool DisableSmemReuseC_,
|
||||
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
|
||||
class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) per-collective
|
||||
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
|
||||
class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N)
|
||||
class ElementC_,
|
||||
class StrideC_,
|
||||
class ElementD_,
|
||||
@@ -65,7 +65,7 @@ template <
|
||||
class CollectiveEpilogue<
|
||||
Sm90TmaWarpSpecialized<StagesC_,StagesD_,DisableSmemReuseC_>,
|
||||
BlockTileShape_,
|
||||
EpilogueTile_,
|
||||
EpilogueTileShape_,
|
||||
ElementC_,
|
||||
StrideC_,
|
||||
ElementD_,
|
||||
@@ -84,7 +84,7 @@ public:
|
||||
//
|
||||
using DispatchPolicy = Sm90TmaWarpSpecialized<StagesC_,StagesD_,DisableSmemReuseC_>;
|
||||
using BlockTileShape = BlockTileShape_;
|
||||
using EpilogueTile = EpilogueTile_;
|
||||
using EpilogueTileShape = EpilogueTileShape_;
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
||||
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
||||
@@ -103,24 +103,27 @@ public:
|
||||
using SmemLayoutAtomD = SmemLayoutAtomD_;
|
||||
using CopyOpR2S = CopyOpR2S_;
|
||||
|
||||
using GmemTiledCopyC = SM90_TMA_LOAD;
|
||||
using GmemTiledCopyD = SM90_TMA_STORE;
|
||||
|
||||
constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount;
|
||||
constexpr static bool iskThreadEpilogueOpWithBias = detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value;
|
||||
using AlignmentType = typename uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
||||
|
||||
static_assert(sizeof(ElementD) == 2, "Only 16b output supported for now");
|
||||
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
|
||||
static_assert(!is_layout<EpilogueTileShape>::value && is_tuple<EpilogueTileShape>::value, "EpilogueTileShape must be a cute::Shape");
|
||||
static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]");
|
||||
static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
|
||||
static_assert(rank(EpilogueTileShape{}) == 2, "EpilogueTileShape must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
|
||||
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
|
||||
private:
|
||||
using InternalElementC = std::conditional_t<std::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
|
||||
using InternalElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
|
||||
constexpr static int StagesC = StagesC_;
|
||||
constexpr static int StagesD = StagesD_;
|
||||
constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default ||
|
||||
ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
static_assert((std::is_void_v<ElementC> && not is_source_supported) || (not std::is_void_v<ElementC> && is_source_supported));
|
||||
static_assert((cute::is_void_v<ElementC> && not is_source_supported) || (not cute::is_void_v<ElementC> && is_source_supported),
|
||||
"Inconsistent C type and Scale kind");
|
||||
|
||||
// internal optimization to reuse C shared memory for storing D
|
||||
using SmemLayoutAtomBitsC = decltype(downcast<sizeof_bits<InternalElementC>::value>(SmemLayoutAtomC{}));
|
||||
@@ -131,21 +134,14 @@ private:
|
||||
StrideC{} == StrideD{} &&
|
||||
cute::is_same_v<SmemLayoutAtomBitsC,SmemLayoutAtomBitsD>;
|
||||
|
||||
// Find the max contiguous layout usable by TMA (if EpilogueTile is a by-mode tiler)
|
||||
using SmemLayoutTmaD = decltype(tile_to_shape(
|
||||
SmemLayoutAtomD{},
|
||||
make_shape(max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))),
|
||||
max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{})))),
|
||||
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1>, Step<_1,_2>>{} ));
|
||||
|
||||
public:
|
||||
using SmemLayoutC = decltype(tile_to_shape(
|
||||
SmemLayoutAtomC{},
|
||||
make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int<StagesC>{}),
|
||||
cute::conditional_t<get<0>(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
|
||||
using SmemLayoutD = decltype(tile_to_shape(
|
||||
SmemLayoutTmaD{},
|
||||
make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int<StagesD>{}),
|
||||
SmemLayoutAtomD{},
|
||||
make_shape(size<0>(EpilogueTileShape{}), size<1>(EpilogueTileShape{}), Int<StagesD>{}),
|
||||
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
|
||||
|
||||
// TMA pipeline for loading C
|
||||
@@ -194,7 +190,7 @@ public:
|
||||
CopyOpS2G{},
|
||||
make_tensor(static_cast<ElementD const*>(nullptr),
|
||||
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
|
||||
SmemLayoutTmaD{}));
|
||||
SmemLayoutD{}(_,_,0)));
|
||||
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
TMA_C tma_load_c;
|
||||
@@ -210,23 +206,32 @@ public:
|
||||
to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
[[maybe_unused]] void* workspace)
|
||||
{
|
||||
[[maybe_unused]] void* workspace) {
|
||||
// Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{});
|
||||
auto M = get<0>(problem_shape_MNKL);
|
||||
auto N = get<1>(problem_shape_MNKL);
|
||||
auto L = get<3>(problem_shape_MNKL);
|
||||
Tensor tensor_c = make_tensor(static_cast<InternalElementC const*>(args.ptr_C), make_layout(make_shape(M,N,L), args.dC));
|
||||
|
||||
typename Params::TMA_C tma_load_c = [&]() {
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
Tensor tensor_c = make_tensor(static_cast<InternalElementC const*>(args.ptr_C), make_layout(make_shape(M,N,L), args.dC));
|
||||
return make_tma_copy(
|
||||
CopyOpG2S{},
|
||||
tensor_c,
|
||||
SmemLayoutC{}(_,_,0));
|
||||
}
|
||||
else {
|
||||
return typename Params::TMA_C{};
|
||||
}
|
||||
}();
|
||||
|
||||
Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD));
|
||||
typename Params::TMA_C tma_load_c = make_tma_copy(
|
||||
CopyOpG2S{},
|
||||
tensor_c,
|
||||
SmemLayoutC{}(_,_,0));
|
||||
typename Params::TMA_D tma_store_d = make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
tensor_d,
|
||||
SmemLayoutTmaD{});
|
||||
SmemLayoutD{}(_,_,0));
|
||||
|
||||
return {
|
||||
args.thread,
|
||||
tma_load_c,
|
||||
@@ -378,8 +383,8 @@ public:
|
||||
auto L = get<3>(problem_shape_mnkl);
|
||||
auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{});
|
||||
auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{});
|
||||
auto epi_tile_m = size<0>(shape(EpilogueTile{}));
|
||||
auto epi_tile_n = size<1>(shape(EpilogueTile{}));
|
||||
auto epi_tile_m = size<0>(EpilogueTileShape{});
|
||||
auto epi_tile_n = size<1>(EpilogueTileShape{});
|
||||
|
||||
// Represent the full output tensor
|
||||
Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l)
|
||||
@@ -396,11 +401,14 @@ public:
|
||||
SmemLayoutD{});
|
||||
|
||||
// Tile thread(b)lock tensors by (E)pilogue output tile shape (bE)
|
||||
Tensor bEsC = local_tile(sC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgD = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEsC = local_tile(sC, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgD = local_tile(gD, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
// Partition for register to smem copy (tRS_)
|
||||
TiledCopy tiled_r2s = make_tiled_copy_C_atom(Copy_Atom<CopyOpR2S,ElementD>{}, tiled_mma);
|
||||
using CopyAtomR2S = cute::conditional_t<cute::is_same_v<CopyOpR2S,DefaultCopy>,
|
||||
Copy_Atom<UniversalCopy<uint_byte_t<sizeof(ElementD)*2>>,ElementD>,
|
||||
Copy_Atom<CopyOpR2S,ElementD>>;
|
||||
TiledCopy tiled_r2s = make_tiled_copy_C_atom(CopyAtomR2S{}, tiled_mma);
|
||||
ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
|
||||
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
Tensor tRS_sD = conditional_return<ReuseSmemC>(
|
||||
@@ -430,7 +438,7 @@ public:
|
||||
thrblk_s2g.partition_S(bEsD) ); // (S2G,S2G_M,S2G_N,PIPE)
|
||||
Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
|
||||
|
||||
CUTE_STATIC_ASSERT(size<0,0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
|
||||
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
|
||||
CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M");
|
||||
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
|
||||
|
||||
@@ -464,7 +472,13 @@ public:
|
||||
int r2s_v = epi_n * size(tRS_rD_frg);
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
// Copy source tile to register from smem
|
||||
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
|
||||
if constexpr (cute::is_same_v<CopyOpS2R,DefaultCopy>) {
|
||||
copy(tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
|
||||
}
|
||||
else {
|
||||
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(tRS_rD_frg); ++i) {
|
||||
tRS_rD_frg(i) = epilogue_op(tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i));
|
||||
@@ -491,7 +505,12 @@ public:
|
||||
}
|
||||
|
||||
// Copy output tile to smem from register
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
|
||||
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
|
||||
copy(tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
|
||||
}
|
||||
else {
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Issue the TMA store of the previous iteration
|
||||
@@ -514,7 +533,12 @@ public:
|
||||
synchronize();
|
||||
|
||||
// Copy tile to smem from register
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
|
||||
copy(tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
}
|
||||
else {
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
}
|
||||
|
||||
// Advance pipeline state
|
||||
store_pipe_producer_state_prev = store_pipe_producer_state;
|
||||
|
||||
@@ -47,8 +47,8 @@ namespace collective {
|
||||
template <
|
||||
int StagesC_,
|
||||
int StagesD_,
|
||||
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
|
||||
class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) per-collective
|
||||
class BlockTileShape_, // (BLK_M,BLK_N,BLK_K)
|
||||
class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N)
|
||||
class ElementC_,
|
||||
class StrideC_,
|
||||
class ElementD_,
|
||||
@@ -64,7 +64,7 @@ template <
|
||||
class CollectiveEpilogue<
|
||||
Sm90TmaWarpSpecializedBiasElementwise<StagesC_, StagesD_>,
|
||||
BlockTileShape_,
|
||||
EpilogueTile_,
|
||||
EpilogueTileShape_,
|
||||
ElementC_,
|
||||
StrideC_,
|
||||
ElementD_,
|
||||
@@ -81,10 +81,9 @@ public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
// derived types of output thread level operator
|
||||
using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise<StagesC_, StagesD_>;
|
||||
using BlockTileShape = BlockTileShape_;
|
||||
using EpilogueTile = EpilogueTile_;
|
||||
using EpilogueTileShape = EpilogueTileShape_;
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
||||
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
||||
@@ -106,6 +105,9 @@ public:
|
||||
using SmemLayoutAtomD = SmemLayoutAtomD_;
|
||||
using CopyOpR2S = CopyOpR2S_;
|
||||
|
||||
using GmemTiledCopyC = SM90_TMA_LOAD;
|
||||
using GmemTiledCopyD = SM90_TMA_STORE;
|
||||
|
||||
constexpr static bool StoreT = ThreadEpilogueOp::kStoreT;
|
||||
constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount;
|
||||
static_assert(detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value,
|
||||
@@ -113,26 +115,28 @@ public:
|
||||
constexpr static bool iskThreadEpilogueOpWithBias = true;
|
||||
using AlignmentType = typename uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
||||
|
||||
static_assert(sizeof(ElementC) == 2, "Only 16b source supported for now");
|
||||
static_assert(sizeof(ElementD) == 2, "Only 16b output supported for now");
|
||||
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
|
||||
static_assert(!is_layout<EpilogueTileShape>::value && is_tuple<EpilogueTileShape>::value, "EpilogueTileShape must be a cute::Shape");
|
||||
static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]");
|
||||
static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
|
||||
static_assert(rank(EpilogueTileShape{}) == 2, "EpilogueTileShape must be rank-2: [EPI_TILE_M,EPI_TILE_N]");
|
||||
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
|
||||
private:
|
||||
using InternalElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
|
||||
constexpr static int StagesC = StagesC_;
|
||||
constexpr static int StagesD = StagesD_;
|
||||
constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default ||
|
||||
ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
|
||||
static_assert((cute::is_void_v<ElementC> && not is_source_supported) || (not cute::is_void_v<ElementC> && is_source_supported),
|
||||
"Inconsistent C type and Scale kind");
|
||||
|
||||
// Find the max contiguous layout usable by TMA (if EpilogueTile is a by-mode tiler)
|
||||
using SmemLayoutTmaD = decltype(tile_to_shape(
|
||||
SmemLayoutAtomD{},
|
||||
make_shape(max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))),
|
||||
max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{})))),
|
||||
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1>, Step<_1,_2>>{} ));
|
||||
// internal optimization to reuse C shared memory for storing D
|
||||
using SmemLayoutAtomBitsC = decltype(downcast<sizeof_bits<InternalElementC>::value>(SmemLayoutAtomC{}));
|
||||
using SmemLayoutAtomBitsD = decltype(downcast<sizeof_bits<ElementD>::value>(SmemLayoutAtomD{}));
|
||||
constexpr static bool ReuseSmemC = is_source_supported &&
|
||||
sizeof(InternalElementC) == sizeof(ElementD) &&
|
||||
StrideC{} == StrideD{} &&
|
||||
cute::is_same_v<SmemLayoutAtomBitsC,SmemLayoutAtomBitsD> &&
|
||||
not StoreT;
|
||||
|
||||
public:
|
||||
using SmemLayoutC = decltype(tile_to_shape(
|
||||
@@ -140,29 +144,31 @@ public:
|
||||
make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int<StagesC>{}),
|
||||
cute::conditional_t<get<0>(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
|
||||
using SmemLayoutD = decltype(tile_to_shape(
|
||||
SmemLayoutTmaD{},
|
||||
make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int<StagesD>{}),
|
||||
SmemLayoutAtomD{},
|
||||
make_shape(size<0>(EpilogueTileShape{}), size<1>(EpilogueTileShape{}), Int<StagesD>{}),
|
||||
cute::conditional_t<get<0>(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
|
||||
|
||||
// TMA pipeline for loading C
|
||||
using LoadPipeline = cutlass::PipelineTransactionAsync<is_source_supported ? StagesC : 0>;
|
||||
using LoadPipelineState = cutlass::PipelineState<is_source_supported ? StagesC : 0>;
|
||||
constexpr static uint32_t TmaTransactionBytes =
|
||||
size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof(ElementC));
|
||||
size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof(InternalElementC));
|
||||
|
||||
// TMA pipeline for storing D and T
|
||||
using StorePipeline = cutlass::PipelineTmaStore<StagesD>;
|
||||
using StorePipelineState = cutlass::PipelineState<StagesD>;
|
||||
// TMA pipeline for storing D and T. ReuseSmemC cannot be set to true if StoreT is enabled.
|
||||
using StorePipeline = cutlass::PipelineTmaStore<ReuseSmemC ? StagesC : StagesD>;
|
||||
using StorePipelineState = cutlass::PipelineState<ReuseSmemC ? StagesC : StagesD>;
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : aligned_struct<128> {
|
||||
cute::conditional_t<not is_source_supported,
|
||||
detail::EmptyStorage<ElementC>,
|
||||
array_aligned<ElementC, size(SmemLayoutC{})>> smem_C;
|
||||
alignas(128) array_aligned<ElementD, size(SmemLayoutD{})> smem_D;
|
||||
alignas(128) cute::conditional_t<not StoreT,
|
||||
detail::EmptyStorage<InternalElementC>,
|
||||
array_aligned<InternalElementC, size(SmemLayoutC{})>> smem_C;
|
||||
alignas(128) cute::conditional_t<ReuseSmemC,
|
||||
detail::EmptyStorage<ElementD>,
|
||||
array_aligned<ElementD, size(SmemLayoutD{})>> smem_D;
|
||||
alignas(128) cute::conditional_t<not StoreT,
|
||||
detail::EmptyStorage<ElementT>,
|
||||
array_aligned<ElementT, size(SmemLayoutD{})>> smem_T;
|
||||
array_aligned<ElementT, size(SmemLayoutD{})>> smem_T;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename LoadPipeline::SharedStorage;
|
||||
@@ -173,29 +179,32 @@ public:
|
||||
|
||||
// Host side epilogue arguments
|
||||
struct Arguments {
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
ElementC const* ptr_C = nullptr;
|
||||
StrideC dC{};
|
||||
ElementD* ptr_D = nullptr;
|
||||
StrideD dD{};
|
||||
typename ThreadEpilogueOp::Params thread;
|
||||
ElementC const* ptr_C;
|
||||
StrideC dC;
|
||||
ElementD const* ptr_D;
|
||||
StrideD dD;
|
||||
ElementBias const* ptr_Bias = nullptr;
|
||||
ElementT* ptr_T = nullptr;
|
||||
ElementT const* ptr_T = nullptr;
|
||||
};
|
||||
|
||||
// Device side epilogue params
|
||||
// Device side epilgoue params
|
||||
struct Params {
|
||||
using TMA_C = decltype(make_tma_copy(
|
||||
CopyOpG2S{},
|
||||
make_tensor(static_cast<ElementC const*>(nullptr), repeat_like(StrideC{}, int32_t(0)), StrideC{}),
|
||||
make_tensor(static_cast<InternalElementC const*>(nullptr),
|
||||
repeat_like(StrideC{}, int32_t(0)), StrideC{}),
|
||||
SmemLayoutC{}(_,_,0)));
|
||||
using TMA_D = decltype(make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
make_tensor(static_cast<ElementD*>(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD_{}),
|
||||
SmemLayoutTmaD{}));
|
||||
make_tensor(static_cast<ElementD const*>(nullptr),
|
||||
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
|
||||
SmemLayoutD{}(_,_,0)));
|
||||
using TMA_T = decltype(make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
make_tensor(static_cast<ElementT*>(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD{}),
|
||||
SmemLayoutTmaD{}));
|
||||
make_tensor(static_cast<ElementT const*>(nullptr),
|
||||
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
|
||||
SmemLayoutD{}(_,_,0)));
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
TMA_C tma_load_c;
|
||||
TMA_D tma_store_d;
|
||||
@@ -209,29 +218,42 @@ public:
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) {
|
||||
to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
[[maybe_unused]] void* workspace) {
|
||||
// Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{});
|
||||
auto M = get<0>(problem_shape_MNKL);
|
||||
auto N = get<1>(problem_shape_MNKL);
|
||||
auto L = get<3>(problem_shape_MNKL);
|
||||
Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC));
|
||||
|
||||
typename Params::TMA_C tma_load_c = [&]() {
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
Tensor tensor_c = make_tensor(static_cast<InternalElementC const*>(args.ptr_C), make_layout(make_shape(M,N,L), args.dC));
|
||||
return make_tma_copy(
|
||||
CopyOpG2S{},
|
||||
tensor_c,
|
||||
SmemLayoutC{}(_,_,0));
|
||||
}
|
||||
else {
|
||||
return typename Params::TMA_C{};
|
||||
}
|
||||
}();
|
||||
|
||||
Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD));
|
||||
typename Params::TMA_C tma_load_c = make_tma_copy(
|
||||
CopyOpG2S{},
|
||||
tensor_c,
|
||||
SmemLayoutC{}(_,_,0));
|
||||
typename Params::TMA_D tma_store_d = make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
tensor_d,
|
||||
SmemLayoutTmaD{});
|
||||
SmemLayoutD{}(_,_,0));
|
||||
|
||||
typename Params::TMA_T tma_store_t = [&]() {
|
||||
if constexpr (StoreT) {
|
||||
Tensor tensor_t = make_tensor(args.ptr_T, make_layout(make_shape(M,N,L), args.dD));
|
||||
return make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
tensor_t,
|
||||
SmemLayoutTmaD{});
|
||||
SmemLayoutD{}(_,_,0));
|
||||
}
|
||||
else {
|
||||
return typename Params::TMA_T{};
|
||||
@@ -262,6 +284,10 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
static constexpr int
|
||||
get_store_pipe_increment(TileShapeMNK tile_shape_MNK) {
|
||||
if constexpr (ReuseSmemC) {
|
||||
return get_load_pipe_increment(tile_shape_MNK);
|
||||
}
|
||||
|
||||
// Compute number of D subtiles
|
||||
constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutD{});
|
||||
constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutD{});
|
||||
@@ -276,7 +302,7 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
bool
|
||||
is_source_needed() {
|
||||
return epilogue_op.is_source_needed();
|
||||
return is_source_supported && epilogue_op.is_source_needed();
|
||||
}
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
@@ -390,8 +416,8 @@ public:
|
||||
auto L = get<3>(problem_shape_mnkl);
|
||||
auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{});
|
||||
auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{});
|
||||
auto epi_tile_m = size<0>(shape(EpilogueTile{}));
|
||||
auto epi_tile_n = size<1>(shape(EpilogueTile{}));
|
||||
auto epi_tile_m = size<0>(EpilogueTileShape{});
|
||||
auto epi_tile_n = size<1>(EpilogueTileShape{});
|
||||
|
||||
// Represent the full output tensor
|
||||
Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l)
|
||||
@@ -407,7 +433,7 @@ public:
|
||||
Tensor gT = gT_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N)
|
||||
Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N)
|
||||
|
||||
// Construct the smem tensors for source (sC) and output (sD)
|
||||
// Construct the smem tensors for source (sC) and output (sD, sT)
|
||||
Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), // (TILE_M,TILE_N)
|
||||
SmemLayoutC{})(_,_,load_pipe_consumer_state.index());
|
||||
Tensor bEsD = make_tensor(make_smem_ptr(shared_tensors.smem_D.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE)
|
||||
@@ -416,21 +442,26 @@ public:
|
||||
SmemLayoutD{});
|
||||
|
||||
// Tile thread(b)lock tensors by (E)pilogue output tile shape (bE)
|
||||
Tensor bEsC = local_tile(sC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgD = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgT = local_tile(gT, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgBias = local_tile(gBias, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEsC = local_tile(sC, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgD = local_tile(gD, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgT = local_tile(gT, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor bEgBias = local_tile(gBias, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
// Partition for register to smem copy (tRS_)
|
||||
TiledCopy tiled_r2s = make_tiled_copy_C_atom(Copy_Atom<CopyOpR2S,ElementD>{}, tiled_mma);
|
||||
using CopyAtomR2S = cute::conditional_t<cute::is_same_v<CopyOpR2S,DefaultCopy>,
|
||||
Copy_Atom<UniversalCopy<uint_byte_t<sizeof(ElementD)*2>>,ElementD>,
|
||||
Copy_Atom<CopyOpR2S,ElementD>>;
|
||||
TiledCopy tiled_r2s = make_tiled_copy_C_atom(CopyAtomR2S{}, tiled_mma);
|
||||
ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
|
||||
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
Tensor tRS_sD = thread_r2s.partition_D(bEsD); // (R2S,R2S_M,R2S_N,PIPE)
|
||||
Tensor tRS_sD = conditional_return<ReuseSmemC>(
|
||||
thread_r2s.partition_D(recast<ElementD>(bEsC)), // (R2S,R2S_M,R2S_N,EPI_M,EPI_N)
|
||||
thread_r2s.partition_D(bEsD) ); // (R2S,R2S_M,R2S_N,PIPE)
|
||||
Tensor tRS_sT = thread_r2s.partition_D(bEsT); // (R2S,R2S_M,R2S_N,PIPE)
|
||||
|
||||
// Allocate register tensors
|
||||
auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(bEsD))); // (R2S,R2S_M,R2S_N)
|
||||
Tensor tRS_rC = make_tensor<ElementC>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
|
||||
Tensor tRS_rC = make_tensor<InternalElementC>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
|
||||
Tensor tRS_rD = make_tensor<ElementD>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
|
||||
Tensor tRS_rT = make_tensor<ElementT>(tRS_rD_shape); // (R2S,R2S_M,R2S_N)
|
||||
|
||||
@@ -445,21 +476,23 @@ public:
|
||||
Tensor tRS_rBias_frg = recast<typename ThreadEpilogueOp::FragmentBias>(tRS_rBias);
|
||||
|
||||
// Partition for smem to register copy (tSR_)
|
||||
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R,ElementC>{}, tiled_r2s);
|
||||
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R,InternalElementC>{}, tiled_r2s);
|
||||
ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx);
|
||||
Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N)
|
||||
Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N)
|
||||
|
||||
// Partition for smem to gmem copy (tSG_)
|
||||
ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{});
|
||||
Tensor tSG_sD = thrblk_s2g.partition_S(bEsD); // (S2G,S2G_M,S2G_N,PIPE)
|
||||
Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
|
||||
Tensor tSG_sD = conditional_return<ReuseSmemC>(
|
||||
thrblk_s2g.partition_S(recast<ElementD>(bEsC)), // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
|
||||
thrblk_s2g.partition_S(bEsD) ); // (S2G,S2G_M,S2G_N,PIPE)
|
||||
Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
|
||||
|
||||
ThrCopy thrblk_s2g_t = params.tma_store_t.get_slice(Int<0>{});
|
||||
Tensor tSG_sT = thrblk_s2g_t.partition_S(bEsT); // (S2G,S2G_M,S2G_N,PIPE)
|
||||
Tensor tSG_gT = thrblk_s2g_t.partition_D(bEgT); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
|
||||
|
||||
CUTE_STATIC_ASSERT(size<0,0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
|
||||
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly");
|
||||
CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M");
|
||||
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
|
||||
|
||||
@@ -470,11 +503,15 @@ public:
|
||||
// Predication for TMA store (one warp issues TMA store)
|
||||
bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0;
|
||||
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
if (is_source_supported && epilogue_op.is_source_needed()) {
|
||||
// Wait for epilogue load to fill smem buffer with C
|
||||
load_pipeline.consumer_wait(load_pipe_consumer_state);
|
||||
}
|
||||
|
||||
// Delay issue of TMA store by 1 iteration to achieve better instruction pipelining
|
||||
PipelineState store_pipe_producer_state_prev = store_pipe_producer_state;
|
||||
int epi_m_prev = 0, epi_n_prev = 0;
|
||||
|
||||
// For each output tile
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_n = 0; epi_n < size<3>(bEgD); ++epi_n) {
|
||||
@@ -490,9 +527,14 @@ public:
|
||||
|
||||
// Elementwise operation with conversion
|
||||
int r2s_v = epi_n * size(tRS_rD_frg);
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
if (is_source_supported && epilogue_op.is_source_needed()) {
|
||||
// Copy source tile to registers from smem
|
||||
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
|
||||
if constexpr (cute::is_same_v<CopyOpS2R,DefaultCopy>) {
|
||||
copy(tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
|
||||
}
|
||||
else {
|
||||
copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(tRS_rD_frg); ++i) {
|
||||
@@ -506,40 +548,119 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for a smem buffer to be available
|
||||
if (issue_tma_store) {
|
||||
store_pipeline.producer_acquire(store_pipe_producer_state);
|
||||
}
|
||||
synchronize();
|
||||
if constexpr (ReuseSmemC) {
|
||||
// If ReuseSmemC is true, StoreT must be false. Therefore, we do not perform copies for T in this block.
|
||||
|
||||
// Copy tile to smem from register
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
// Issue the TMA store of the previous iteration
|
||||
if (not (epi_m == 0 && epi_n == 0)) {
|
||||
// Make sure smem writes are visible to TMA
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
synchronize(); // ensure all threads have issued their async fence
|
||||
|
||||
if constexpr (StoreT) {
|
||||
copy(tiled_r2s, tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index()));
|
||||
}
|
||||
|
||||
// Make sure smem writes are visible to TMA
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
synchronize(); // ensure all threads have issued their async fence
|
||||
|
||||
// Write the tile to gmem from smem with TMA
|
||||
if (issue_tma_store) {
|
||||
copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state.index()), tSG_gD(_,_,_,epi_m,epi_n));
|
||||
if constexpr (StoreT) {
|
||||
copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state.index()), tSG_gT(_,_,_,epi_m,epi_n));
|
||||
// Write the tile to gmem from smem with TMA
|
||||
if (issue_tma_store) {
|
||||
copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
|
||||
}
|
||||
}
|
||||
store_pipeline.producer_commit(store_pipe_producer_state);
|
||||
|
||||
// Copy output tile to smem from register
|
||||
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
|
||||
copy(tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
|
||||
}
|
||||
else {
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n));
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Issue the TMA store of the previous iteration
|
||||
if (not (epi_m == 0 && epi_n == 0)) {
|
||||
// Make sure smem writes are visible to TMA
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
synchronize(); // ensure all threads have issued their async fence
|
||||
|
||||
// Write the tile to gmem from smem with TMA
|
||||
if (issue_tma_store) {
|
||||
copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
|
||||
if constexpr (StoreT) {
|
||||
copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state_prev.index()), tSG_gT(_,_,_,epi_m_prev,epi_n_prev));
|
||||
}
|
||||
store_pipeline.producer_commit(store_pipe_producer_state_prev);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for a smem buffer to be available
|
||||
if (issue_tma_store) {
|
||||
store_pipeline.producer_acquire(store_pipe_producer_state);
|
||||
}
|
||||
synchronize();
|
||||
|
||||
// Copy tile to smem from register
|
||||
if constexpr (cute::is_same_v<CopyOpR2S,DefaultCopy>) {
|
||||
copy(tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
if constexpr (StoreT) {
|
||||
copy(tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index()));
|
||||
}
|
||||
}
|
||||
else {
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
if constexpr (StoreT) {
|
||||
copy(tiled_r2s, tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index()));
|
||||
}
|
||||
}
|
||||
|
||||
// Advance pipeline state
|
||||
store_pipe_producer_state_prev = store_pipe_producer_state;
|
||||
++store_pipe_producer_state;
|
||||
}
|
||||
|
||||
// Advance pipeline state
|
||||
++store_pipe_producer_state;
|
||||
epi_m_prev = epi_m;
|
||||
epi_n_prev = epi_n;
|
||||
}
|
||||
}
|
||||
|
||||
// Let dma warp know smem buffer is consumed and empty
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
load_pipeline.consumer_release(load_pipe_consumer_state);
|
||||
if constexpr (ReuseSmemC) {
|
||||
// If ReuseSmemC is true, StoreT must be false. Therefore, we do not perform copies for T in this block.
|
||||
|
||||
// Fence and issue the TMA store of the last iteration
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
synchronize(); // ensure all threads have issued their async fence
|
||||
if (issue_tma_store) {
|
||||
copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
|
||||
}
|
||||
|
||||
// Arrive and advance pipeline state
|
||||
if (issue_tma_store) {
|
||||
store_pipeline.producer_commit(store_pipe_producer_state);
|
||||
}
|
||||
++store_pipe_producer_state;
|
||||
|
||||
// Wait for a smem buffer to be available
|
||||
if (issue_tma_store) {
|
||||
store_pipeline.producer_acquire(store_pipe_producer_state);
|
||||
}
|
||||
synchronize();
|
||||
|
||||
// Let dma warp know smem buffer is consumed and empty
|
||||
if (is_source_supported && epilogue_op.is_source_needed()) {
|
||||
load_pipeline.consumer_release(store_pipe_producer_state);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Fence and issue the TMA store of the last iteration
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
synchronize(); // ensure all threads have issued their async fence
|
||||
if (issue_tma_store) {
|
||||
copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev));
|
||||
if (StoreT) {
|
||||
copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state_prev.index()), tSG_gT(_,_,_,epi_m_prev,epi_n_prev));
|
||||
}
|
||||
store_pipeline.producer_commit(store_pipe_producer_state_prev);
|
||||
}
|
||||
|
||||
// Let dma warp know smem buffer is consumed and empty
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
load_pipeline.consumer_release(load_pipe_consumer_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ public:
|
||||
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
||||
using FragmentC = Array<ElementOutput, kElementsPerAccess>;
|
||||
using FragmentC = Array<ElementC, kElementsPerAccess>;
|
||||
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
||||
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
||||
|
||||
|
||||
@@ -28,9 +28,9 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
/*!
|
||||
\file
|
||||
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
|
||||
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
|
||||
batched array variants.
|
||||
*/
|
||||
|
||||
@@ -57,7 +57,7 @@ namespace cutlass::gemm::device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!
|
||||
/*!
|
||||
GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel
|
||||
of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal.
|
||||
|
||||
@@ -159,10 +159,10 @@ public:
|
||||
typename CollectiveMainloop::GmemTiledCopyA, ElementA>();
|
||||
static int constexpr kAlignmentB = gemm::detail::get_alignment_count_from_gmem_tiled_copy<
|
||||
typename CollectiveMainloop::GmemTiledCopyB, ElementB>();
|
||||
|
||||
// NOTE: 3.0 DefaultEpilogues don't support vectorized stores (yet)
|
||||
static int constexpr kAlignmentC = 1;
|
||||
static int constexpr kAlignmentD = 1;
|
||||
static int constexpr kAlignmentC = gemm::detail::get_alignment_count_from_gmem_tiled_copy<
|
||||
typename CollectiveEpilogue::GmemTiledCopyC, ElementC>();
|
||||
static int constexpr kAlignmentD = gemm::detail::get_alignment_count_from_gmem_tiled_copy<
|
||||
typename CollectiveEpilogue::GmemTiledCopyD, ElementD>();
|
||||
|
||||
using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
@@ -327,7 +327,7 @@ public:
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::run()");
|
||||
dim3 constexpr block = GemmKernel::get_block_shape();
|
||||
dim3 const block = GemmKernel::get_block_shape();
|
||||
dim3 const grid = get_grid_shape(params);
|
||||
|
||||
// configure smem size and carveout
|
||||
@@ -404,19 +404,19 @@ public:
|
||||
|
||||
using GemmKernel = GemmKernel_;
|
||||
|
||||
static bool const kInternalTranspose =
|
||||
static bool const kInternalTranspose =
|
||||
cute::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value;
|
||||
|
||||
using ThreadblockShape = typename GemmKernel::Mma::Shape;
|
||||
using WarpShape = typename GemmKernel::WarpShape;
|
||||
using InstructionShape = typename GemmKernel::InstructionShape;
|
||||
|
||||
// warp-level, arch-level (instruction), math operator
|
||||
// warp-level, arch-level (instruction), math operator
|
||||
using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator;
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename WarpMmaOperator::MathOperator;
|
||||
|
||||
// Operator class and arch tag extract bottom-up
|
||||
|
||||
// Operator class and arch tag extract bottom-up
|
||||
// set it for top-level gemm device-level template
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
@@ -444,15 +444,15 @@ public:
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
|
||||
|
||||
using ElementC = typename GemmKernel::ElementC;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
static int const kAlignmentC = GemmKernel::kAlignmentC;
|
||||
|
||||
|
||||
// C and D same type for 2.x kernel
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
|
||||
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
@@ -493,12 +493,12 @@ public:
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
|
||||
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const &args) {
|
||||
static dim3 get_grid_shape(Arguments const &args) {
|
||||
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
|
||||
}
|
||||
|
||||
@@ -532,12 +532,12 @@ public:
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
@@ -75,6 +75,8 @@ public:
|
||||
static ComplexTransform const kTransformB = GemvKernel::kTransformB;
|
||||
|
||||
static int const kThreadCount = GemvKernel::kThreadCount;
|
||||
static int const kThreadsPerRow = GemvKernel::kThreadsPerRow;
|
||||
|
||||
static int const kStages = GemvKernel::kStages;
|
||||
|
||||
static int const kAlignmentA = GemvKernel::kAlignmentA;
|
||||
@@ -106,8 +108,23 @@ public:
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const &args) {
|
||||
return dim3((args.problem_size.row() + (kThreadCount - 1)) / kThreadCount, 1, args.batch_count % 65565);
|
||||
static dim3 get_grid_shape(Arguments const &args, dim3 const &block) {
|
||||
if(platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
return dim3((args.problem_size.row() + (block.x - 1)) / block.x, 1, args.batch_count % 65536);
|
||||
}
|
||||
else {
|
||||
return dim3((args.problem_size.row() + (block.y - 1)) / block.y, 1, args.batch_count % 65536);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the block shape
|
||||
static dim3 get_block_shape() {
|
||||
if(platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
return dim3(kThreadCount, 1, 1);
|
||||
}
|
||||
else {
|
||||
return dim3(kThreadsPerRow, kThreadCount / kThreadsPerRow, 1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes Gemv state from arguments.
|
||||
@@ -124,8 +141,8 @@ public:
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
dim3 grid = get_grid_shape(params_);
|
||||
dim3 block(GemvKernel::kThreadCount, 1, 1);
|
||||
dim3 block = get_block_shape();
|
||||
dim3 grid = get_grid_shape(params_, block);
|
||||
|
||||
int smem_size = int(sizeof(typename GemvKernel::SharedStorage));
|
||||
|
||||
@@ -137,11 +154,7 @@ public:
|
||||
//
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemvKernel_>
|
||||
class GemvStridedBatched {
|
||||
public:
|
||||
|
||||
using GemvKernel = GemvKernel_;
|
||||
|
||||
using ElementA = typename GemvKernel::ElementA;
|
||||
using LayoutA = typename GemvKernel::LayoutA;
|
||||
using ElementB = typename GemvKernel::ElementB;
|
||||
using ElementC = typename GemvKernel::ElementC;
|
||||
|
||||
using ElementAccumulator = typename GemvKernel::ElementAccumulator;
|
||||
|
||||
using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp;
|
||||
|
||||
static ComplexTransform const kTransformA = GemvKernel::kTransformA;
|
||||
static ComplexTransform const kTransformB = GemvKernel::kTransformB;
|
||||
|
||||
static int const kThreadCount = GemvKernel::kThreadCount;
|
||||
static int const mThreadCount = GemvKernel::mThreadCount;
|
||||
|
||||
static int const kStages = GemvKernel::kStages;
|
||||
|
||||
static int const kAlignmentA = GemvKernel::kAlignmentA;
|
||||
static int const kAlignmentB = GemvKernel::kAlignmentB;
|
||||
static int const kAlignmentC = GemvKernel::kAlignmentC;
|
||||
|
||||
using Arguments = typename GemvKernel::Arguments;
|
||||
using Params = typename GemvKernel::Params;
|
||||
|
||||
private:
|
||||
|
||||
Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the Gemv.
|
||||
GemvStridedBatched() {}
|
||||
|
||||
/// Determines whether the Gemv can execute the given problem.
|
||||
static Status can_implement(Arguments const& args) {
|
||||
return GemvKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
|
||||
/// Initializes Gemv state from arguments.
|
||||
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
params_ = Params(args);
|
||||
|
||||
if (args.problem_size.column() % GemvKernel::kElementsPerAccess) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
return params_.update(args);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
dim3 grid(1, 1, params_.batch_count % 65536);
|
||||
dim3 block(kThreadCount, mThreadCount, 1);
|
||||
|
||||
int smem_size = 0;
|
||||
|
||||
// Launch
|
||||
cutlass::Kernel<GemvKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
//
|
||||
// Query for errors
|
||||
//
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) { return run(stream); }
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -522,15 +522,27 @@ stride_to_layout_tag_B() {
|
||||
template <class GmemTiledCopy, class Element>
|
||||
constexpr int
|
||||
get_alignment_count_from_gmem_tiled_copy() {
|
||||
// For TMA tiled copies, we know the alignment has to be 128 bits
|
||||
if constexpr ( cute::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy>
|
||||
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>
|
||||
) {
|
||||
return 128 / sizeof_bits<Element>::value;
|
||||
if constexpr (cute::is_void_v<GmemTiledCopy>) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Account for ElementC = void kernels
|
||||
else if constexpr (cute::is_void_v<Element>) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
else {
|
||||
// For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN
|
||||
return GmemTiledCopy::NumValSrc;
|
||||
// For TMA tiled copies, we know the alignment has to be 128 bits
|
||||
if constexpr ( cute::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy>
|
||||
|| cute::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>
|
||||
|| cute::is_base_of_v<cute::SM90_TMA_STORE, GmemTiledCopy>
|
||||
) {
|
||||
return 128 / sizeof_bits<Element>::value;
|
||||
}
|
||||
else {
|
||||
// For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN
|
||||
return GmemTiledCopy::NumValSrc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -41,9 +41,13 @@
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@@ -58,18 +62,49 @@ template <
|
||||
typename ElementB_,
|
||||
typename ElementC_,
|
||||
typename ElementAccumulator_,
|
||||
typename EpilogueOutputOp_
|
||||
typename EpilogueOutputOp_,
|
||||
int kElementsPerAccess_ = 1, ///< Number of elements involved in a global access.
|
||||
int kThreadCount_ = 0, ///< Number of threads in the thread block.
|
||||
/// It will be calculated automatically if set to 0.
|
||||
int kThreadsPerRow_ = 0 ///< Number of threads in the k dimension.
|
||||
/// It will be calculated automatically if set to 0.
|
||||
>
|
||||
struct Gemv {
|
||||
struct Gemv;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Specializations
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GEMV for column-major A matrix
|
||||
template <
|
||||
typename ElementA_,
|
||||
typename ElementB_,
|
||||
typename ElementC_,
|
||||
typename ElementAccumulator_,
|
||||
typename EpilogueOutputOp_,
|
||||
int kElementsPerAccess_,
|
||||
int kThreadCount_,
|
||||
int kThreadsPerRow_
|
||||
>
|
||||
struct Gemv <
|
||||
ElementA_,
|
||||
layout::ColumnMajor,
|
||||
ElementB_,
|
||||
ElementC_,
|
||||
ElementAccumulator_,
|
||||
EpilogueOutputOp_,
|
||||
kElementsPerAccess_,
|
||||
kThreadCount_,
|
||||
kThreadsPerRow_
|
||||
>{
|
||||
public:
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::ColumnMajor;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
static_assert(platform::is_same<LayoutA, LayoutA_>::value,
|
||||
"Only supported for column-major A matrix");
|
||||
|
||||
using ElementB = ElementB_;
|
||||
using ElementC = ElementC_;
|
||||
|
||||
@@ -79,7 +114,10 @@ public:
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// thread block shape (kThreadCount, 1, 1)
|
||||
static int const kThreadCount = (kThreadCount_ == 0) ? 32 : kThreadCount_;
|
||||
static int const kThreadsPerRow = kThreadsPerRow_;
|
||||
|
||||
static int const kStages = 1;
|
||||
|
||||
static int const kAlignmentA = 1;
|
||||
@@ -121,17 +159,17 @@ public:
|
||||
MatrixCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueOutputOp::Params output_op,
|
||||
TensorRefA ref_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t inc_B,
|
||||
int64_t inc_C,
|
||||
int64_t inc_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D
|
||||
TensorRefA ref_A,
|
||||
void const *ptr_B,
|
||||
void const *ptr_C,
|
||||
void *ptr_D,
|
||||
int64_t inc_B,
|
||||
int64_t inc_C,
|
||||
int64_t inc_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D
|
||||
):
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
@@ -151,14 +189,44 @@ public:
|
||||
|
||||
Arguments(
|
||||
MatrixCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueOutputOp::Params output_op,
|
||||
TensorRefA ref_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t inc_B,
|
||||
int64_t inc_C,
|
||||
int64_t inc_D
|
||||
TensorRefA ref_A,
|
||||
void const *ptr_B,
|
||||
void const *ptr_C,
|
||||
void *ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D
|
||||
):
|
||||
Arguments(
|
||||
problem_size,
|
||||
batch_count,
|
||||
output_op,
|
||||
ref_A,
|
||||
ptr_B,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D)
|
||||
{ }
|
||||
|
||||
Arguments(
|
||||
MatrixCoord problem_size,
|
||||
typename EpilogueOutputOp::Params output_op,
|
||||
TensorRefA ref_A,
|
||||
void const *ptr_B,
|
||||
void const *ptr_C,
|
||||
void *ptr_D,
|
||||
int64_t inc_B,
|
||||
int64_t inc_C,
|
||||
int64_t inc_D
|
||||
):
|
||||
Arguments(
|
||||
problem_size,
|
||||
@@ -206,7 +274,6 @@ public:
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::MatrixCoord const & problem_size) {
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
@@ -214,7 +281,7 @@ public:
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
/// Executes one GEMV
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
@@ -282,6 +349,288 @@ public:
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GEMV for row-major A matrix
|
||||
template <
|
||||
typename ElementA_,
|
||||
typename ElementB_,
|
||||
typename ElementC_,
|
||||
typename ElementAccumulator_,
|
||||
typename EpilogueOutputOp_,
|
||||
int kElementsPerAccess_,
|
||||
int kThreadCount_,
|
||||
int kThreadsPerRow_
|
||||
>
|
||||
struct Gemv <
|
||||
ElementA_,
|
||||
layout::RowMajor,
|
||||
ElementB_,
|
||||
ElementC_,
|
||||
ElementAccumulator_,
|
||||
EpilogueOutputOp_,
|
||||
kElementsPerAccess_,
|
||||
kThreadCount_,
|
||||
kThreadsPerRow_
|
||||
>{
|
||||
public:
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::RowMajor;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
using ElementB = ElementB_;
|
||||
using ElementC = ElementC_;
|
||||
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using EpilogueOutputOp = EpilogueOutputOp_;
|
||||
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
static FloatRoundStyle const Round = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
|
||||
// number of return elements in a global access
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using FragmentA = Array<ElementA, kElementsPerAccess>;
|
||||
using FragmentB = Array<ElementB, kElementsPerAccess>;
|
||||
using FragmentCompute = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
// thread block shape (kThreadsPerRow, kThreadCount / kThreadsPerRow, 1)
|
||||
static int const kThreadCount = (kThreadCount_ == 0) ? 128 : kThreadCount_;
|
||||
static int const kThreadsPerRow = (kThreadsPerRow_ == 0) ?
|
||||
std::min(static_cast<int>(kThreadCount / (kElementsPerAccess * sizeof(ElementA))), 16)
|
||||
: kThreadsPerRow_;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
MatrixCoord problem_size;
|
||||
int32_t batch_count;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
TensorRefA ref_A;
|
||||
|
||||
ElementB const *ptr_B;
|
||||
ElementC const *ptr_C;
|
||||
ElementC *ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments(): batch_count(0) { }
|
||||
|
||||
Arguments(
|
||||
MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
typename EpilogueOutputOp::Params output_op,
|
||||
TensorRefA ref_A,
|
||||
void const *ptr_B,
|
||||
void const *ptr_C,
|
||||
void *ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D
|
||||
):
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
output_op(output_op),
|
||||
ref_A(ref_A),
|
||||
ptr_B(static_cast<ElementB const *>(ptr_B)),
|
||||
ptr_C(static_cast<ElementC const *>(ptr_C)),
|
||||
ptr_D(static_cast<ElementC *>(ptr_D)),
|
||||
batch_stride_A(batch_stride_A),
|
||||
batch_stride_B(batch_stride_B),
|
||||
batch_stride_C(batch_stride_C),
|
||||
batch_stride_D(batch_stride_D)
|
||||
{ }
|
||||
|
||||
Arguments(
|
||||
MatrixCoord problem_size,
|
||||
typename EpilogueOutputOp::Params output_op,
|
||||
TensorRefA ref_A,
|
||||
void const *ptr_B,
|
||||
void const *ptr_C,
|
||||
void *ptr_D
|
||||
):
|
||||
Arguments(
|
||||
problem_size,
|
||||
1,
|
||||
output_op,
|
||||
ref_A,
|
||||
ptr_B,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1)
|
||||
{ }
|
||||
|
||||
Status update(Arguments const &args) {
|
||||
problem_size = args.problem_size;
|
||||
batch_count = args.batch_count;
|
||||
output_op = args.output_op;
|
||||
ref_A = ref_A;
|
||||
ptr_B = args.ptr_B;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Gemv() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::MatrixCoord const &problem_size) {
|
||||
if (problem_size.column() % kElementsPerAccess != 0) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const &args) {
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
/// Executes one GEMV
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Loop over batch indices
|
||||
for (int batch_idx = blockIdx.z; batch_idx < params.batch_count; batch_idx += gridDim.z) {
|
||||
int idx_col_k = threadIdx.x;
|
||||
int idx_row_m = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
if (idx_row_m < params.problem_size.row()) {
|
||||
// problem_size (row = m, column = k)
|
||||
// matrix A (batch, m, k)
|
||||
// vector B (batch, 1, k)
|
||||
// vector C (batch, m, 1)
|
||||
// vector D (batch, m, 1)
|
||||
|
||||
// move in the batch dimension
|
||||
ElementA const *ptr_A = params.ref_A.data() + batch_idx * params.batch_stride_A;
|
||||
ElementB const *ptr_B = params.ptr_B + batch_idx * params.batch_stride_B;
|
||||
|
||||
ElementC const *ptr_C = params.ptr_C + batch_idx * params.batch_stride_C;
|
||||
ElementC *ptr_D = params.ptr_D + batch_idx * params.batch_stride_D;
|
||||
|
||||
// move in the k dimension
|
||||
ptr_A += idx_col_k * kElementsPerAccess;
|
||||
ptr_B += idx_col_k * kElementsPerAccess;
|
||||
|
||||
// move in the m dimension
|
||||
ptr_A += idx_row_m * params.problem_size.column();
|
||||
ptr_C += idx_row_m;
|
||||
ptr_D += idx_row_m;
|
||||
|
||||
NumericArrayConverter<ElementAccumulator, ElementA, kElementsPerAccess, Round> srcA_converter;
|
||||
NumericArrayConverter<ElementAccumulator, ElementB, kElementsPerAccess, Round> srcB_converter;
|
||||
|
||||
ElementAccumulator accum = 0.f;
|
||||
|
||||
FragmentB fragB;
|
||||
FragmentA fragA;
|
||||
|
||||
int unroll_col_k = 0;
|
||||
|
||||
// rows of the rolling tile
|
||||
int const tileA_k = kThreadsPerRow * kElementsPerAccess;
|
||||
|
||||
for (; unroll_col_k < params.problem_size.column() / tileA_k * tileA_k; unroll_col_k += tileA_k) {
|
||||
|
||||
// fetch from matrix A
|
||||
arch::global_load<FragmentA,
|
||||
sizeof(FragmentA),
|
||||
arch::CacheOperation::LastUse>(fragA, (ptr_A + unroll_col_k), true);
|
||||
|
||||
// fetch from vector B
|
||||
arch::global_load<FragmentB,
|
||||
sizeof(FragmentB),
|
||||
arch::CacheOperation::Always>(fragB, (ptr_B + unroll_col_k), true);
|
||||
|
||||
FragmentCompute fragB_Compute = srcB_converter(fragB);
|
||||
FragmentCompute fragA_Compute = srcA_converter(fragA);
|
||||
|
||||
// Math
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int e = 0; e < kElementsPerAccess; e++) {
|
||||
accum += fragA_Compute.at(e) * fragB_Compute.at(e);
|
||||
}
|
||||
}
|
||||
|
||||
// calculate the rest of K elements
|
||||
// each thread fetch 1 element each time
|
||||
for (int k = unroll_col_k + idx_col_k; k < params.problem_size.column(); k += kThreadsPerRow) {
|
||||
ElementB b = *(ptr_B - idx_col_k * kElementsPerAccess + k);
|
||||
ElementA a = *(ptr_A - idx_col_k * kElementsPerAccess + k);
|
||||
|
||||
accum += ElementAccumulator(a) * ElementAccumulator(b);
|
||||
}
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
typename EpilogueOutputOp::FragmentOutput source_fragment;
|
||||
|
||||
// prefetch from source matrix C
|
||||
if (output_op.is_source_needed()) {
|
||||
source_fragment[0] = *(ptr_C);
|
||||
}
|
||||
|
||||
typename EpilogueOutputOp::FragmentAccumulator accum_fragment;
|
||||
typename EpilogueOutputOp::FragmentOutput output_fragment;
|
||||
|
||||
for (int mask = (kThreadsPerRow >> 1); mask > 0; mask >>= 1) {
|
||||
accum += __shfl_xor_sync(0xFFFFFFFF, accum, mask, 32);
|
||||
}
|
||||
|
||||
if (idx_col_k == 0) {
|
||||
accum_fragment[0] = accum;
|
||||
|
||||
if (output_op.is_source_needed()) {
|
||||
output_fragment = output_op(accum_fragment, source_fragment);
|
||||
}
|
||||
else {
|
||||
output_fragment = output_op(accum_fragment);
|
||||
}
|
||||
|
||||
*ptr_D = output_fragment[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA_, /// matrix
|
||||
typename LayoutA_,
|
||||
typename ElementB_, /// vector
|
||||
typename ElementC_,
|
||||
typename ElementAccumulator_,
|
||||
int kElementsPerAccess_,
|
||||
typename EpilogueOutputOp_
|
||||
>
|
||||
struct GemvStridedBatched {
|
||||
public:
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::RowMajor;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
static_assert(std::is_same<LayoutA, LayoutA_>::value,
|
||||
"Only supported for row-major A matrix");
|
||||
|
||||
using ElementB = ElementB_;
|
||||
using ElementC = ElementC_;
|
||||
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using EpilogueOutputOp = EpilogueOutputOp_;
|
||||
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
static FloatRoundStyle const Round = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
|
||||
// number of return elements in a global access
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using FragmentA = Array<ElementA, kElementsPerAccess>;
|
||||
using FragmentB = Array<ElementB, kElementsPerAccess>;
|
||||
using FragmentCompute = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
// thread block shape (kThreadCount, mThreadCount)
|
||||
static int const kThreadCount = std::min(static_cast<int>(128 / (kElementsPerAccess * sizeof(ElementA))), 16);
|
||||
static int const mThreadCount = 128 / kThreadCount;
|
||||
|
||||
// rolling tile shape
|
||||
static int const kTileA = kThreadCount * kElementsPerAccess;
|
||||
static int const mTileA = mThreadCount * 8;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
MatrixCoord problem_size;
|
||||
int32_t batch_count;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
TensorRefA ref_A;
|
||||
|
||||
ElementB const *ptr_B;
|
||||
ElementC const *ptr_C;
|
||||
ElementC *ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments() : batch_count(0) {}
|
||||
|
||||
Arguments(
|
||||
MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
|
||||
typename EpilogueOutputOp::Params output_op,
|
||||
TensorRefA ref_A,
|
||||
void const *ptr_B,
|
||||
void const *ptr_C,
|
||||
void *ptr_D,
|
||||
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D) : problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
output_op(output_op),
|
||||
ref_A(ref_A),
|
||||
ptr_B(static_cast<ElementB const *>(ptr_B)),
|
||||
ptr_C(static_cast<ElementC const *>(ptr_C)),
|
||||
ptr_D(static_cast<ElementC *>(ptr_D)),
|
||||
|
||||
batch_stride_A(batch_stride_A),
|
||||
batch_stride_B(batch_stride_B),
|
||||
batch_stride_C(batch_stride_C),
|
||||
batch_stride_D(batch_stride_D)
|
||||
{
|
||||
}
|
||||
|
||||
Arguments(
|
||||
MatrixCoord problem_size,
|
||||
typename EpilogueOutputOp::Params output_op,
|
||||
TensorRefA ref_A,
|
||||
void const *ptr_B,
|
||||
void const *ptr_C,
|
||||
void *ptr_D) : Arguments(problem_size,
|
||||
1,
|
||||
1,
|
||||
output_op,
|
||||
ref_A,
|
||||
ptr_B,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1)
|
||||
{
|
||||
}
|
||||
|
||||
Status update(Arguments const &args)
|
||||
{
|
||||
problem_size = args.problem_size;
|
||||
batch_count = args.batch_count;
|
||||
output_op = args.output_op;
|
||||
ref_A = ref_A;
|
||||
ptr_B = args.ptr_B;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GemvStridedBatched() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::MatrixCoord const &problem_size)
|
||||
{
|
||||
if (problem_size.column() % kElementsPerAccess != 0)
|
||||
return Status::kErrorMisalignedOperand;
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const &args)
|
||||
{
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
/// Executes one GEMV
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage)
|
||||
{
|
||||
// Loop over batch indices
|
||||
for (int batch_idx = blockIdx.z; batch_idx < params.batch_count; batch_idx += gridDim.z)
|
||||
{
|
||||
int k_col_id = threadIdx.x;
|
||||
int m_row_id = threadIdx.y;
|
||||
|
||||
// problem_size (row = m, column = k)
|
||||
// matrix A (batch, m, k)
|
||||
// vector B (batch, 1, k)
|
||||
// vector C (batch, m, 1)
|
||||
// vector D (batch, m, 1)
|
||||
|
||||
// move in the batch dimension
|
||||
ElementA const *ptr_A = params.ref_A.data() + batch_idx * params.batch_stride_A;
|
||||
ElementB const *ptr_B = params.ptr_B + batch_idx * params.batch_stride_B;
|
||||
|
||||
ElementC const *ptr_C = params.ptr_C + batch_idx * params.batch_stride_C;
|
||||
ElementC *ptr_D = params.ptr_D + batch_idx * params.batch_stride_D;
|
||||
|
||||
// move in the k dimension
|
||||
ptr_A += k_col_id * kElementsPerAccess;
|
||||
ptr_B += k_col_id * kElementsPerAccess;
|
||||
|
||||
// move in the m dimension
|
||||
ptr_A += m_row_id * params.problem_size.column();
|
||||
ptr_C += m_row_id;
|
||||
ptr_D += m_row_id;
|
||||
|
||||
NumericArrayConverter<ElementAccumulator, ElementA, kElementsPerAccess, Round> srcA_converter;
|
||||
NumericArrayConverter<ElementAccumulator, ElementB, kElementsPerAccess, Round> srcB_converter;
|
||||
|
||||
for (; m_row_id < params.problem_size.row(); m_row_id += mTileA)
|
||||
{
|
||||
ElementAccumulator accum[mTileA / mThreadCount] = {0.f};
|
||||
|
||||
FragmentB fragB;
|
||||
FragmentA fragA[mTileA / mThreadCount];
|
||||
|
||||
int mElemCountPerTile = min(mTileA / mThreadCount, (params.problem_size.row() - m_row_id - 1) / mThreadCount + 1);
|
||||
|
||||
int kUnroll = 0;
|
||||
|
||||
for (; kUnroll < params.problem_size.column() / kTileA * kTileA; kUnroll += kTileA)
|
||||
{
|
||||
for (int m = 0; m < mElemCountPerTile; m++)
|
||||
{
|
||||
// fetch from matrix A
|
||||
arch::global_load<FragmentA,
|
||||
sizeof(FragmentA),
|
||||
arch::CacheOperation::LastUse>(fragA[m], (ptr_A + kUnroll + m * mThreadCount * params.problem_size.column()), true);
|
||||
}
|
||||
|
||||
// fetch from vector B
|
||||
arch::global_load<FragmentB,
|
||||
sizeof(FragmentB),
|
||||
arch::CacheOperation::Always>(fragB, (ptr_B + kUnroll), true);
|
||||
|
||||
for (int m = 0; m < mElemCountPerTile; m++)
|
||||
{
|
||||
FragmentCompute fragB_Compute = srcB_converter(fragB);
|
||||
FragmentCompute fragA_Compute = srcA_converter(fragA[m]);
|
||||
|
||||
// Math
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int e = 0; e < kElementsPerAccess; e++)
|
||||
{
|
||||
accum[m] += fragA_Compute.at(e) * fragB_Compute.at(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculate the rest of K elements
|
||||
// each thread fetch 1 element each time
|
||||
for (int k = kUnroll + k_col_id; k < params.problem_size.column(); k += kThreadCount)
|
||||
{
|
||||
ElementB b = *(ptr_B - k_col_id * kElementsPerAccess + k);
|
||||
for (int m = 0; m < mElemCountPerTile; m++)
|
||||
{
|
||||
ElementA a = *(ptr_A - k_col_id * kElementsPerAccess + k + m * mThreadCount * params.problem_size.column());
|
||||
accum[m] += ElementAccumulator(a) * ElementAccumulator(b);
|
||||
}
|
||||
}
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
typename EpilogueOutputOp::FragmentOutput source_fragment[mTileA / mThreadCount];
|
||||
|
||||
// prefetch from source matrix C
|
||||
if (output_op.is_source_needed())
|
||||
{
|
||||
for (int m = 0; m < mElemCountPerTile; m++)
|
||||
{
|
||||
source_fragment[m][0] = *(ptr_C + m * mThreadCount);
|
||||
}
|
||||
}
|
||||
|
||||
typename EpilogueOutputOp::FragmentAccumulator accum_fragment;
|
||||
typename EpilogueOutputOp::FragmentOutput output_fragment;
|
||||
|
||||
for (int m = 0; m < mElemCountPerTile; m++)
|
||||
{
|
||||
for (int mask = (kThreadCount >> 1); mask > 0; mask >>= 1)
|
||||
{
|
||||
accum[m] += __shfl_xor_sync(0xFFFFFFFF, accum[m], mask, 32);
|
||||
}
|
||||
|
||||
if (k_col_id == 0)
|
||||
{
|
||||
accum_fragment[0] = accum[m];
|
||||
|
||||
if (output_op.is_source_needed())
|
||||
{
|
||||
output_fragment = output_op(accum_fragment, source_fragment[m]);
|
||||
}
|
||||
else
|
||||
{
|
||||
output_fragment = output_op(accum_fragment);
|
||||
}
|
||||
|
||||
*(ptr_D + m * mThreadCount) = output_fragment[0];
|
||||
}
|
||||
}
|
||||
|
||||
ptr_A += mTileA * params.problem_size.column();
|
||||
ptr_C += mTileA;
|
||||
ptr_D += mTileA;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -129,21 +129,18 @@ public:
|
||||
};
|
||||
}
|
||||
|
||||
static
|
||||
bool
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
return args.mode == GemmUniversalMode::kGemm or
|
||||
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
|
||||
}
|
||||
|
||||
static
|
||||
int
|
||||
static int
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
int batch_count = 1;
|
||||
if constexpr (rank(ProblemShape{}) == 4) {
|
||||
@@ -157,8 +154,7 @@ public:
|
||||
);
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
@@ -172,20 +172,20 @@ public:
|
||||
auto N = get<1>(args.problem_shape);
|
||||
auto K = get<2>(args.problem_shape);
|
||||
// Contiguous dimension for the TMA tensor should be 128b aligned
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
|
||||
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
|
||||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
return implementable;
|
||||
}
|
||||
|
||||
constexpr bool is_beta_supported =
|
||||
constexpr bool is_beta_supported =
|
||||
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
|
||||
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
|
||||
if (!implementable) {
|
||||
@@ -196,15 +196,13 @@ public:
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static
|
||||
int
|
||||
static int
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
auto cluster_shape = ClusterShape{};
|
||||
auto tile_shape = TileShape{};
|
||||
@@ -213,8 +211,7 @@ public:
|
||||
problem_shape_MNKL, tile_shape, cluster_shape);
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
@@ -243,7 +240,7 @@ public:
|
||||
int warp_idx = canonical_warp_idx();
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
@@ -179,20 +179,20 @@ public:
|
||||
auto N = get<1>(args.problem_shape);
|
||||
auto K = get<2>(args.problem_shape);
|
||||
// Contiguous dimension for the TMA tensor should be 128b aligned
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
|
||||
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
|
||||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
return implementable;
|
||||
}
|
||||
|
||||
constexpr bool is_beta_supported =
|
||||
constexpr bool is_beta_supported = not cute::is_void_v<ElementC> &&
|
||||
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
|
||||
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
|
||||
if (!implementable) {
|
||||
@@ -210,8 +210,7 @@ public:
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
auto cluster_shape = ClusterShape{};
|
||||
auto tile_shape = TileShape{};
|
||||
@@ -220,8 +219,7 @@ public:
|
||||
problem_shape_MNKL, tile_shape, cluster_shape);
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
@@ -300,7 +298,7 @@ public:
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
|
||||
@@ -202,20 +202,20 @@ public:
|
||||
auto N = get<1>(args.problem_shape);
|
||||
auto K = get<2>(args.problem_shape);
|
||||
// Contiguous dimension for the TMA tensor should be 128b aligned
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
|
||||
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
|
||||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
return implementable;
|
||||
}
|
||||
|
||||
constexpr bool is_beta_supported =
|
||||
constexpr bool is_beta_supported =
|
||||
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
|
||||
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
|
||||
if (!implementable) {
|
||||
@@ -233,15 +233,13 @@ public:
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
|
||||
return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info);
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
@@ -333,7 +331,7 @@ public:
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
|
||||
@@ -110,7 +110,7 @@ public:
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<
|
||||
StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||
@@ -210,20 +210,20 @@ public:
|
||||
auto N = get<1>(args.problem_shape);
|
||||
auto K = get<2>(args.problem_shape);
|
||||
// Contiguous dimension for the TMA tensor should be 128b aligned
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
implementable = std::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>, layout::RowMajor> ?
|
||||
K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0;
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
implementable = implementable && (std::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0);
|
||||
implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value ||
|
||||
(cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA<CollectiveEpilogue>::value &&
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
std::is_same_v<gemm::detail::StrideToLayoutTagC_t<StrideC>, layout::RowMajor> ?
|
||||
N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0));
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
return implementable;
|
||||
}
|
||||
|
||||
constexpr bool is_beta_supported =
|
||||
constexpr bool is_beta_supported =
|
||||
CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default;
|
||||
implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr);
|
||||
if (!implementable) {
|
||||
@@ -241,15 +241,13 @@ public:
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
|
||||
return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info);
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
@@ -341,7 +339,7 @@ public:
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
@@ -389,9 +387,9 @@ public:
|
||||
detail::PersistentTileSchedulerSm90 scheduler;
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
scheduler.advance_to_next_work();
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count);
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
@@ -486,7 +484,7 @@ public:
|
||||
params.mainloop
|
||||
);
|
||||
|
||||
// Cue for next Math WG's MMA to start
|
||||
// Cue for next Math WG's MMA to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
@@ -522,7 +520,7 @@ public:
|
||||
// Wait for all TMA stores to complete
|
||||
epi_store_pipeline.producer_tail(epi_store_pipe_producer_state);
|
||||
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Get next work tile
|
||||
|
||||
@@ -108,7 +108,7 @@ public:
|
||||
return {work_idx_m, work_idx_n, static_cast<int32_t>(work_idx_l), current_work_linear_idx_ < scheduler_params.blocks_per_problem_};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
advance_to_next_work(uint32_t advance_count = 1) {
|
||||
current_work_linear_idx_ += grid_blocks_total_ * advance_count;
|
||||
@@ -117,7 +117,7 @@ public:
|
||||
// Given the inputs, computes the total number of output blocks this problem will compute over
|
||||
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
|
||||
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
|
||||
CUTLASS_HOST_DEVICE constexpr static
|
||||
CUTLASS_HOST_DEVICE static
|
||||
dim3
|
||||
get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) {
|
||||
// Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles
|
||||
@@ -135,7 +135,7 @@ public:
|
||||
|
||||
// Given the inputs, computes the physical grid we should launch.
|
||||
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
|
||||
CUTLASS_HOST_DEVICE constexpr static
|
||||
CUTLASS_HOST_DEVICE static
|
||||
dim3
|
||||
get_grid_shape(ProblemShapeMNKL problem_shape_mnk, BlockShape blk_shape, ClusterShape cluster_shape, KernelHardwareInfo hw_info) {
|
||||
int const sm_count = hw_info.sm_count;
|
||||
|
||||
@@ -630,6 +630,12 @@ public:
|
||||
accum = plus_accum(accum, tmp_accum);
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -690,6 +690,11 @@ public:
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -846,12 +846,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -660,12 +660,10 @@ public:
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
}
|
||||
|
||||
// Optionally commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -628,6 +628,12 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -739,6 +739,11 @@ public:
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -650,6 +650,12 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -528,12 +528,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -524,10 +524,10 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
|
||||
if (kAdvanceRank) {
|
||||
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
|
||||
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
||||
pointer_ += Shape::kContiguous * tile_offset.contiguous() * sizeof_bits<Element>::value / 8;
|
||||
} else {
|
||||
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
|
||||
pointer_ += Shape::kStrided * tile_offset.strided();
|
||||
pointer_ += Shape::kStrided * tile_offset.strided() * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
} else {
|
||||
coord_offset_.strided() = the_predicates.thread_offset_.strided() + Shape::kStrided * (tile_offset.strided() - kAdvanceRank);
|
||||
|
||||
@@ -156,11 +156,11 @@ static bool
|
||||
can_implement(Arguments const& args);
|
||||
|
||||
// Returns a dim3 representing the threadblock shape.
|
||||
static constexpr dim3
|
||||
static dim3
|
||||
get_block_shape();
|
||||
|
||||
// Returns a dim3 representing the grid shape in terms of threadblocks.
|
||||
static constexpr dim3
|
||||
static dim3
|
||||
get_grid_shape(Params const& params);
|
||||
```
|
||||
|
||||
|
||||
@@ -353,8 +353,9 @@ Hopper architecture and beyond so as to indicate new features of the kernel with
|
||||
|
||||
To best illustrate this naming convention, we will walk through the meaning of each of the components
|
||||
in a GEMM kernel used by the profiler:
|
||||
|
||||
```
|
||||
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f32_128x128x64_2x1x1_0_ntn_align8
|
||||
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f32_128x128x64_2x1x1_0_ntn_align8
|
||||
```
|
||||
|
||||
The components within this name are as follows:
|
||||
@@ -366,8 +367,7 @@ The components within this name are as follows:
|
||||
* `s`: indicates that the Tensor Core instruction being used accumulates in single precision
|
||||
(as opposed to `h`, which indicates half precision)
|
||||
* `64x128x16gemm`: indicates that the shape of the Tensor Core instruction being used (MxNxK) is 64x128x16
|
||||
* `f16_f16_f32_f16`: indicates that the data types for operands A, B, and C are each `f16`
|
||||
(half precision) and that accumulation is performed using `f32` (single precision)
|
||||
* `f16_f16_f32_f16_f16`: indicates that the data types for operands A, B, Accumulator, C and D (in that order).
|
||||
* `128x128x64`: indicates that the thread block shape used in the GEMM (MxNxK) is 128x128x64
|
||||
* `2x1x1`: indicates that the cluster shape being used is 2x1x1
|
||||
* `0`: indicates that the kernel uses the CollectiveBuilder's automatic stage calculation to determine the
|
||||
@@ -382,12 +382,24 @@ Note that in some special cases where the input A/B types do not match that of t
|
||||
instruction's, the MMA facing input type is added to the instruction string as well.
|
||||
|
||||
```
|
||||
cutlass3x_sm90_tensorop_s64x128x8tf32gemm_f32_f32_f32_f32_128x128x32_2x1x1_0_tnn_align4
|
||||
cutlass3x_sm90_tensorop_s64x128x8tf32gemm_f32_f32_f32_f32_f32_128x128x32_2x1x1_0_tnn_align4
|
||||
```
|
||||
|
||||
* `s64x128x8tf32gemm`: indicates that the MMA consumes inputs in `tf32` format, and therefore
|
||||
the kernel performs rounding of the `f32` values in global memory while loading them into shared memory.
|
||||
|
||||
For custom mainloop or epilogue schedules, details of the opted-in schedule are appended to the end of the
|
||||
kernel name. For example,
|
||||
|
||||
```
|
||||
cutlass3x_sm90_tensorop_h64x128x16gemm_f16_f16_f16_void_f16_128x128x64_1x1x1_0_nnn_align8_warpspecialized_cooperative_epi_tma
|
||||
```
|
||||
|
||||
* `warpspecialized_cooperative`: Mainloop employs a persistent warp-specialized mainloop and kernel schedule.
|
||||
* `epi_tma`: Kernel epilogue employs TMA based vectorization.
|
||||
* `f16_f16_f16_void_f16`: In this case, C type is set to `void`, indicating that residual matrix support
|
||||
is disabled.
|
||||
|
||||
# Convolution
|
||||
|
||||
The CUTLASS Profiler is capable of executing 2-D and 3-D convolution problems for forwards and backwards
|
||||
|
||||
@@ -41,7 +41,6 @@ add_custom_target(
|
||||
cutlass_test_unit_gemm_device_tensorop_planar_complex
|
||||
cutlass_test_unit_gemm_device_sparse_tensorop_sm80
|
||||
cutlass_test_unit_gemv_device
|
||||
cutlass_test_unit_gemv_device_strided_batched
|
||||
cutlass_test_unit_gemm_device_tensorop_sm90
|
||||
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
|
||||
)
|
||||
@@ -61,7 +60,6 @@ add_custom_target(
|
||||
test_unit_gemm_device_tensorop_planar_complex
|
||||
test_unit_gemm_device_sparse_tensorop_sm80
|
||||
test_unit_gemv_device
|
||||
test_unit_gemv_device_strided_batched
|
||||
test_unit_gemm_device_tensorop_sm90
|
||||
)
|
||||
|
||||
@@ -500,15 +498,6 @@ cutlass_test_unit_add_executable(
|
||||
gemv.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemv_device_strided_batched
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
gemv_strided_batched.cu
|
||||
)
|
||||
|
||||
if (NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
|
||||
add_dependencies(
|
||||
|
||||
@@ -98,7 +98,7 @@ public:
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
uint64_t seed_ = 2023
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
||||
|
||||
@@ -156,22 +156,29 @@ public:
|
||||
|
||||
/// Initializes data structures
|
||||
void initialize(
|
||||
cutlass::MatrixCoord problem_size
|
||||
cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count
|
||||
) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
// Allocate the GEMV workspace
|
||||
//
|
||||
|
||||
tensor_A.resize(problem_size);
|
||||
tensor_B.resize({problem_size.column(), 1});
|
||||
tensor_C.resize({problem_size.row(), 1});
|
||||
tensor_D.resize({problem_size.row(), 1});
|
||||
reference_D.resize({problem_size.row(), 1}, false);
|
||||
if(std::is_same<LayoutA, cutlass::layout::ColumnMajor>::value) {
|
||||
tensor_A.resize({problem_size.row(), batch_count * problem_size.column()});
|
||||
}
|
||||
else {
|
||||
tensor_A.resize({batch_count * problem_size.row(), problem_size.column()});
|
||||
}
|
||||
|
||||
tensor_B.resize({batch_count * problem_size.column(), 1});
|
||||
tensor_C.resize({batch_count * problem_size.row(), 1});
|
||||
tensor_D.resize({batch_count * problem_size.row(), 1});
|
||||
reference_D.resize({batch_count * problem_size.row(), 1}, false);
|
||||
|
||||
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 1));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 3));
|
||||
|
||||
// It is possible to randomly initialize to all zeros, so override this with non-zeros
|
||||
// in the upper left corner of each operand.
|
||||
@@ -225,9 +232,14 @@ public:
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Verifies the result is a GEMM
|
||||
/// Verifies the result
|
||||
bool verify(
|
||||
cutlass::MatrixCoord problem_size,
|
||||
cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta) {
|
||||
|
||||
@@ -242,7 +254,7 @@ public:
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
{problem_size.row(), 1, problem_size.column()},
|
||||
alpha,
|
||||
alpha,
|
||||
tensor_A.host_ref(),
|
||||
Gemv::kTransformA,
|
||||
tensor_B.host_ref(),
|
||||
@@ -250,7 +262,12 @@ public:
|
||||
beta,
|
||||
tensor_C.host_ref(),
|
||||
reference_D.host_ref(),
|
||||
ElementAccumulator(0)
|
||||
ElementAccumulator(0),
|
||||
batch_count,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
return compare_reference(problem_size, alpha, beta);
|
||||
@@ -259,39 +276,50 @@ public:
|
||||
/// Runs one problem size
|
||||
bool run(
|
||||
cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta) {
|
||||
|
||||
this->initialize(problem_size);
|
||||
this->initialize(problem_size, batch_count);
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
// Initialize the GEMV operator
|
||||
//
|
||||
|
||||
typename Gemv::Arguments arguments{
|
||||
problem_size,
|
||||
batch_count,
|
||||
{alpha, beta},
|
||||
tensor_A.device_ref(),
|
||||
tensor_B.device_data(),
|
||||
tensor_C.device_data(),
|
||||
tensor_D.device_data(),
|
||||
tensor_B.layout().stride(0),
|
||||
tensor_C.layout().stride(0),
|
||||
tensor_D.layout().stride(0)
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D
|
||||
};
|
||||
|
||||
Gemv gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
|
||||
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
||||
|
||||
size_t workspace_size = Gemv::get_workspace_size(arguments);
|
||||
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
|
||||
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
// Run the GEMV
|
||||
//
|
||||
|
||||
status = gemm_op();
|
||||
@@ -302,8 +330,15 @@ public:
|
||||
// Verify
|
||||
//
|
||||
|
||||
bool passed = this->verify(problem_size, alpha, beta);
|
||||
|
||||
bool passed = this->verify(
|
||||
problem_size,
|
||||
batch_count,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D,
|
||||
alpha,
|
||||
beta);
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
@@ -315,12 +350,16 @@ bool TestAllGemv() {
|
||||
|
||||
using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
int Batch[] = {
|
||||
1, 520, 1314
|
||||
};
|
||||
|
||||
int M[] = {
|
||||
8, 48, 192, 520
|
||||
1, 5, 16
|
||||
};
|
||||
|
||||
int K[] = {
|
||||
8, 192, 528
|
||||
8, 128, 256
|
||||
};
|
||||
|
||||
double Alpha[] = {
|
||||
@@ -331,15 +370,25 @@ bool TestAllGemv() {
|
||||
0, 1, 1.25
|
||||
};
|
||||
|
||||
for (int m : M) {
|
||||
for (int k : K) {
|
||||
for (double alpha : Alpha) {
|
||||
for (double beta : Beta) {
|
||||
for (int b : Batch) {
|
||||
for (int m : M) {
|
||||
for (int k : K) {
|
||||
for (double alpha : Alpha) {
|
||||
for (double beta : Beta) {
|
||||
|
||||
TestbedGemv<Gemv> testbed;
|
||||
TestbedGemv<Gemv> testbed;
|
||||
|
||||
if (!testbed.run({m, k}, ElementCompute(alpha), ElementCompute(beta))) {
|
||||
return false;
|
||||
if (!testbed.run(
|
||||
{m, k},
|
||||
b,
|
||||
m * k,
|
||||
k,
|
||||
m,
|
||||
m,
|
||||
ElementCompute(alpha),
|
||||
ElementCompute(beta))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -354,66 +403,100 @@ bool TestAllGemv() {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, Simple) {
|
||||
|
||||
using ElementOutput = float;
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using Gemv = cutlass::gemm::device::Gemv<
|
||||
cutlass::gemm::kernel::Gemv<
|
||||
ElementOutput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementOutput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element Accumulator
|
||||
EpilogueOp // Output operator
|
||||
>
|
||||
>;
|
||||
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_Gemv_f16n_f16_f32_simt_f32, Simple) {
|
||||
TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, RowMajorA) {
|
||||
|
||||
using ElementInput = cutlass::half_t;
|
||||
using ElementOutput = float;
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
int const kElementsPerAccess = 8;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using Gemv = cutlass::gemm::device::Gemv<
|
||||
cutlass::gemm::kernel::Gemv<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element Accumulator
|
||||
EpilogueOp // Output operator
|
||||
>
|
||||
>;
|
||||
|
||||
cutlass::gemm::kernel::Gemv<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element accumulator
|
||||
EpilogueOp, // Output operator
|
||||
kElementsPerAccess // Element access granularity
|
||||
>
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, Simple) {
|
||||
TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, RowMajorA) {
|
||||
|
||||
using ElementInput = float;
|
||||
using ElementOutput = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = float;
|
||||
int const kElementsPerAccess = 4;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using Gemv = cutlass::gemm::device::Gemv<
|
||||
cutlass::gemm::kernel::Gemv<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element accumulator
|
||||
EpilogueOp, // Output operator
|
||||
kElementsPerAccess // Element access granularity
|
||||
>
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_Gemv_f64n_f64_f64_simt_f64, RowMajorA) {
|
||||
|
||||
using ElementInput = double;
|
||||
using ElementOutput = double;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = double;
|
||||
int const kElementsPerAccess = 2;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using Gemv = cutlass::gemm::device::Gemv<
|
||||
cutlass::gemm::kernel::Gemv<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element accumulator
|
||||
EpilogueOp, // Output operator
|
||||
kElementsPerAccess // Element access granularity
|
||||
>
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, ColumnMajorA) {
|
||||
|
||||
using ElementInput = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
@@ -442,3 +525,63 @@ TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, Simple) {
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, ColumnMajorA) {
|
||||
|
||||
using ElementInput = float;
|
||||
using ElementOutput = float;
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using Gemv = cutlass::gemm::device::Gemv<
|
||||
cutlass::gemm::kernel::Gemv<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element Accumulator
|
||||
EpilogueOp // Output operator
|
||||
>
|
||||
>;
|
||||
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_Gemv_f64n_f64_f64_simt_f64, ColumnMajorA) {
|
||||
|
||||
using ElementInput = double;
|
||||
using ElementOutput = double;
|
||||
using LayoutA = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = double;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using Gemv = cutlass::gemm::device::Gemv<
|
||||
cutlass::gemm::kernel::Gemv<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element Accumulator
|
||||
EpilogueOp // Output operator
|
||||
>
|
||||
>;
|
||||
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<Gemv>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -1,490 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Tests for device-wide strided batched GEMV interface
|
||||
*/
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/kernel/gemv_strided_batched.h"
|
||||
#include "cutlass/gemm/device/gemv_strided_batched.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/gemm_complex.h"
|
||||
|
||||
#include "testbed_utils.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace test {
|
||||
namespace gemm {
|
||||
|
||||
template <typename GemvStridedBatched>
|
||||
class TestbedStridedBatchedGemv
|
||||
{
|
||||
public:
|
||||
|
||||
using ElementA = typename GemvStridedBatched::ElementA;
|
||||
using LayoutA = typename GemvStridedBatched::LayoutA;
|
||||
using ElementB = typename GemvStridedBatched::ElementB;
|
||||
using ElementC = typename GemvStridedBatched::ElementC;
|
||||
|
||||
using ElementAccumulator = typename GemvStridedBatched::ElementAccumulator;
|
||||
using ElementCompute = typename GemvStridedBatched::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using LayoutV = cutlass::layout::RowMajor;
|
||||
|
||||
private:
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
|
||||
cutlass::HostTensor<ElementB, LayoutV> tensor_B;
|
||||
cutlass::HostTensor<ElementC, LayoutV> tensor_C;
|
||||
cutlass::HostTensor<ElementC, LayoutV> tensor_D;
|
||||
cutlass::HostTensor<ElementC, LayoutV> reference_D;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
TestbedStridedBatchedGemv(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2023):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {}
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<typename GemvStridedBatched::ElementC>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
// TODO: Implement the rest
|
||||
EXPECT_TRUE(false) << "Not implemented";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initializes data structures
|
||||
void initialize(
|
||||
cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count
|
||||
) {
|
||||
|
||||
//
|
||||
// Allocate the GEMV workspace
|
||||
//
|
||||
|
||||
tensor_A.resize({batch_count * problem_size.row(), problem_size.column()});
|
||||
tensor_B.resize({batch_count * problem_size.column(), 1});
|
||||
tensor_C.resize({batch_count * problem_size.row(), 1});
|
||||
tensor_D.resize({batch_count * problem_size.row(), 1});
|
||||
reference_D.resize({batch_count * problem_size.row(), 1}, false);
|
||||
|
||||
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 1));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 3));
|
||||
|
||||
// It is possible to randomly initialize to all zeros, so override this with non-zeros
|
||||
// in the upper left corner of each operand.
|
||||
tensor_A.host_view().at({0, 0}) = typename GemvStridedBatched::ElementA(1);
|
||||
tensor_B.host_view().at({0, 0}) = typename GemvStridedBatched::ElementB(1);
|
||||
tensor_C.host_view().at({0, 0}) = typename GemvStridedBatched::ElementC(1);
|
||||
|
||||
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
}
|
||||
|
||||
/// Compares computed reference with device reference and outputs to a file if incorrect
|
||||
bool compare_reference(
|
||||
cutlass::MatrixCoord problem_size,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta) {
|
||||
|
||||
tensor_D.sync_host();
|
||||
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
|
||||
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
|
||||
|
||||
EXPECT_TRUE(passed) << " mismatched reference";
|
||||
|
||||
if (!passed) {
|
||||
|
||||
std::ofstream file("testbed_universal_errors.txt");
|
||||
|
||||
file
|
||||
<< "problem: " << problem_size
|
||||
<< ", alpha: " << alpha << ", beta: " << beta << "\n\n";
|
||||
|
||||
file
|
||||
<< "A =\n" << tensor_A.host_view()
|
||||
<< "\nB =\n" << tensor_B.host_view()
|
||||
<< "\nC =\n" << tensor_C.host_view()
|
||||
<< "\n\nReference =\n" << reference_D.host_view()
|
||||
<< "\nComputed =\n" << tensor_D.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Verifies the result
|
||||
bool verify(
|
||||
cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta) {
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::host::GemmComplex<
|
||||
typename GemvStridedBatched::ElementA, typename GemvStridedBatched::LayoutA,
|
||||
typename GemvStridedBatched::ElementB, LayoutV,
|
||||
typename GemvStridedBatched::ElementC, LayoutV,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
{problem_size.row(), 1, problem_size.column()},
|
||||
alpha,
|
||||
tensor_A.host_ref(),
|
||||
GemvStridedBatched::kTransformA,
|
||||
tensor_B.host_ref(),
|
||||
GemvStridedBatched::kTransformB,
|
||||
beta,
|
||||
tensor_C.host_ref(),
|
||||
reference_D.host_ref(),
|
||||
ElementAccumulator(0),
|
||||
batch_count,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
return compare_reference(problem_size, alpha, beta);
|
||||
}
|
||||
|
||||
/// Runs one problem size
|
||||
bool run(
|
||||
cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta) {
|
||||
|
||||
this->initialize(problem_size, batch_count);
|
||||
|
||||
//
|
||||
// Initialize the GEMV operator
|
||||
//
|
||||
|
||||
typename GemvStridedBatched::Arguments arguments{
|
||||
problem_size,
|
||||
batch_count,
|
||||
{alpha, beta},
|
||||
tensor_A.device_ref(),
|
||||
tensor_B.device_data(),
|
||||
tensor_C.device_data(),
|
||||
tensor_D.device_data(),
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D
|
||||
};
|
||||
|
||||
GemvStridedBatched gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
|
||||
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
||||
|
||||
size_t workspace_size = GemvStridedBatched::get_workspace_size(arguments);
|
||||
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
|
||||
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
||||
|
||||
//
|
||||
// Run the GEMV
|
||||
//
|
||||
|
||||
status = gemm_op();
|
||||
|
||||
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
bool passed = this->verify(
|
||||
problem_size,
|
||||
batch_count,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D,
|
||||
alpha,
|
||||
beta);
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemvStridedBatched>
|
||||
bool TestAllGemv() {
|
||||
|
||||
using ElementCompute = typename GemvStridedBatched::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
int Batch[] = {
|
||||
1, 520, 1314
|
||||
};
|
||||
|
||||
int M[] = {
|
||||
1, 5, 16
|
||||
};
|
||||
|
||||
int K[] = {
|
||||
8, 128, 256
|
||||
};
|
||||
|
||||
double Alpha[] = {
|
||||
1, 1.25
|
||||
};
|
||||
|
||||
double Beta[] = {
|
||||
0, 1, 1.25
|
||||
};
|
||||
|
||||
for (int b : Batch) {
|
||||
for (int m : M) {
|
||||
for (int k : K) {
|
||||
for (double alpha : Alpha) {
|
||||
for (double beta : Beta) {
|
||||
|
||||
TestbedStridedBatchedGemv<GemvStridedBatched> testbed;
|
||||
|
||||
if (!testbed.run(
|
||||
{m, k},
|
||||
b,
|
||||
m * k,
|
||||
k,
|
||||
m,
|
||||
m,
|
||||
ElementCompute(alpha),
|
||||
ElementCompute(beta))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace test
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_StridedBatchedGemv_f16n_f16_f16_simt_f32, Simple) {
|
||||
|
||||
using ElementInput = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
int const kElementsPerAccess = 8;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using GemvStridedBatched = cutlass::gemm::device::GemvStridedBatched<
|
||||
cutlass::gemm::kernel::GemvStridedBatched<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element accumulator
|
||||
kElementsPerAccess, // Element access granularity
|
||||
EpilogueOp // Output operator
|
||||
>>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<GemvStridedBatched>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_StridedBatchedGemv_f32n_f32_f32_simt_f32, Simple) {
|
||||
|
||||
using ElementInput = float;
|
||||
using ElementOutput = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
int const kElementsPerAccess = 4;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using GemvStridedBatched = cutlass::gemm::device::GemvStridedBatched<
|
||||
cutlass::gemm::kernel::GemvStridedBatched<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element accumulator
|
||||
kElementsPerAccess, // Element access granularity
|
||||
EpilogueOp // Output operator
|
||||
>>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<GemvStridedBatched>());}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM50_Device_StridedBatchedGemv_f64n_f64_f64_simt_f64, Simple) {
|
||||
|
||||
using ElementInput = double;
|
||||
using ElementOutput = double;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementAccumulator = double;
|
||||
int const kElementsPerAccess = 2;
|
||||
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
1,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using GemvStridedBatched = cutlass::gemm::device::GemvStridedBatched<
|
||||
cutlass::gemm::kernel::GemvStridedBatched<
|
||||
ElementInput, // Element A
|
||||
LayoutA, // Layout A
|
||||
ElementInput, // Element B
|
||||
ElementOutput, // Element C
|
||||
ElementAccumulator, // Element accumulator
|
||||
kElementsPerAccess, // Element access granularity
|
||||
EpilogueOp // Output operator
|
||||
>>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::TestAllGemv<GemvStridedBatched>());}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -773,7 +773,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
@@ -810,7 +810,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
@@ -847,4 +847,78 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
float, LayoutC, 4,
|
||||
float, LayoutC, 4,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@@ -363,4 +363,48 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU_VoidC) {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
static constexpr bool StoreT = true;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
|
||||
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
void, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@@ -1104,13 +1104,13 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 128
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
|
||||
using ElementA = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::half_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
using ElementC = ElementA;
|
||||
using ElementC = cutlass::half_t;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
@@ -1121,20 +1121,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
cutlass::epilogue::TmaWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, 8,
|
||||
ElementB, LayoutB, 8,
|
||||
ElementA, LayoutA, 16 / sizeof(ElementA),
|
||||
ElementB, LayoutB, 16 / sizeof(ElementB),
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
@@ -1147,13 +1147,13 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
|
||||
using ElementA = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::half_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
using ElementC = ElementA;
|
||||
using ElementC = cutlass::half_t;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
@@ -1164,20 +1164,106 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent, 128x128x64_2
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
cutlass::epilogue::TmaWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, 8,
|
||||
ElementB, LayoutB, 8,
|
||||
ElementA, LayoutA, 16 / sizeof(ElementA),
|
||||
ElementB, LayoutB, 16 / sizeof(ElementB),
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
|
||||
using ElementA = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::half_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
using ElementC = float;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
cutlass::epilogue::TmaWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, 16 / sizeof(ElementA),
|
||||
ElementB, LayoutB, 16 / sizeof(ElementB),
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) {
|
||||
using ElementA = cutlass::half_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using ElementB = cutlass::half_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using ElementAccumulator = float;
|
||||
using ElementC = float;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
ElementC, LayoutC, 16 / sizeof(ElementC),
|
||||
cutlass::epilogue::TmaWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, 16 / sizeof(ElementA),
|
||||
ElementB, LayoutB, 16 / sizeof(ElementB),
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
|
||||
@@ -362,4 +362,48 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU_VoidC) {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
static constexpr bool StoreT = true;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
|
||||
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
float, float,
|
||||
void, LayoutC, 8,
|
||||
cutlass::half_t, LayoutC, 8,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@@ -269,6 +269,150 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x2x1) {
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32_pingpong_epilogue, 64x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 16,
|
||||
int8_t, LayoutC, 16,
|
||||
cutlass::epilogue::TmaWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8t_tensor_op_gmma_s32_pingpong_epilogue, 64x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 16,
|
||||
int8_t, LayoutC, 16,
|
||||
cutlass::epilogue::TmaWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_64,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||
>::CollectiveOp;
|
||||
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32_cooperative_epilogue, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 16,
|
||||
int8_t, LayoutC, 16,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_s8t_s8n_s8t_tensor_op_gmma_s32_cooperative_epilogue, 128x128x128) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
int32_t, int32_t,
|
||||
int8_t, LayoutC, 16,
|
||||
int8_t, LayoutC, 16,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
int8_t, LayoutA, 16,
|
||||
int8_t, LayoutB, 16,
|
||||
int32_t,
|
||||
Shape<_128,_128,_128>, Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@@ -4079,32 +4079,36 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
max_cc = 90
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
tile_descriptions_small = [
|
||||
# Not compatible with TmaWarpSpecializedCooperative
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_large = [
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - Not compatible with TmaWarpSpecializedCooperative
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),- Not compatible with TmaWarpSpecializedCooperative
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - Not compatible with TmaWarpSpecializedCooperative
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_large
|
||||
|
||||
data_type = {
|
||||
"a_type" : math_inst.element_a,
|
||||
@@ -4139,11 +4143,17 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules)
|
||||
|
||||
# persistent kernels with TMA epilogues
|
||||
if data_type["c_type"] in [DataType.f16, DataType.bf16] and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
# Emit instance without C allocation+load
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
# not enough smem for 256x128 f32 out with C allocation
|
||||
if data_type["d_type"] == DataType.f32:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
else:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
# Emit instance without C allocation + load
|
||||
data_type["c_type"] = DataType.void
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
@@ -4170,7 +4180,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules)
|
||||
# persistent kernels with TMA epilogues
|
||||
if data_type_mixed["c_type"] in [DataType.f16, DataType.bf16] and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
@@ -4331,6 +4341,28 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
for data_type in data_types:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type)
|
||||
|
||||
# persistent kernels with TMA epilogues
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
# Emit instance without C allocation+load
|
||||
data_types += [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : math_inst.element_accumulator,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : math_inst.element_accumulator
|
||||
}
|
||||
]
|
||||
for data_type in data_types:
|
||||
# Set alignment d based on Destination format.
|
||||
for layout in layouts:
|
||||
layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684(manifest, cuda_version):
|
||||
|
||||
|
||||
@@ -345,358 +345,6 @@ void initialize_gemm_reference_operations(Manifest &manifest) {
|
||||
complex<double>
|
||||
>(manifest);
|
||||
|
||||
//
|
||||
// FP8 GEMMs
|
||||
//
|
||||
//////////////////////////////////
|
||||
/// ElementC: half_t
|
||||
//////////////////////////////////
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float , // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float , // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
/// ElementC: bfloat16_t
|
||||
//////////////////////////////////
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
//////////////////////////////////
|
||||
/// ElementC: float
|
||||
//////////////////////////////////
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -103,14 +103,20 @@ bool get_cublas_transpose_operation(
|
||||
|
||||
/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration
|
||||
bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type) {
|
||||
switch (element_type) {
|
||||
switch (element_type) {
|
||||
case library::NumericTypeID::kFE4M3:
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
|
||||
data_type = CUDA_R_8F_E4M3;
|
||||
return true;
|
||||
#endif
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE5M2:
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
|
||||
data_type = CUDA_R_8F_E5M2;
|
||||
return true;
|
||||
#endif
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
data_type = CUDA_R_16F;
|
||||
@@ -139,7 +145,7 @@ bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID ele
|
||||
return true;
|
||||
|
||||
case library::NumericTypeID::kS16:
|
||||
break;
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kS32:
|
||||
data_type = CUDA_R_32I;
|
||||
@@ -260,6 +266,13 @@ Status cublas_satisfies(library::GemmDescription const &desc) {
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
// input type BF16 and TF32 not supported in cuBLAS
|
||||
if (desc.A.element == library::NumericTypeID::kBF16 ||
|
||||
desc.A.element == library::NumericTypeID::kTF32) {
|
||||
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user