release 2.11 (#703)

This commit is contained in:
Aditya Atluri
2022-11-19 06:02:15 -08:00
committed by GitHub
parent 3c90f6aea6
commit c975e2ccbb
329 changed files with 47332 additions and 10607 deletions

View File

@@ -78,6 +78,7 @@ void FilterArchitecture() {
{ "SM70*", 70, 75},
{ "SM75*", 75, kMaxDevice},
{ "SM80*", 80, kMaxDevice},
{ "SM90*", 90, kMaxDevice},
{ 0, 0, false }
};

View File

@@ -110,7 +110,9 @@ cutlass_test_unit_add_executable(
# F16
conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
)
if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80)

View File

@@ -776,6 +776,29 @@ struct TestbedGroupConv2dProblemSizes {
2 // groups
));
// Larger problem sizes
default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
{1, 56, 56, 696}, // input size (NHWC)
{768, 3, 3, 232}, // filter size (KRSC)
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
{2, 2}, // stride (stride_h, stride_w)
{1, 1}, // dilation (dilation_h, dilation_w)
cutlass::conv::Mode::kCrossCorrelation,
1, // split_k_slices
3 // groups
));
default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
{1, 14, 14, 1392}, // input size (NHWC)
{1536, 3, 3, 232}, // filter size (KRSC)
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
{1, 1}, // stride (stride_h, stride_w)
{1, 1}, // dilation (dilation_h, dilation_w)
cutlass::conv::Mode::kCrossCorrelation,
1, // split_k_slices
3 // groups
));
////////////////////////////////////////////////////////////////////////////////////
// One CTA calculate multiple groups: CTA::N % k_per_group = 0
////////////////////////////////////////////////////////////////////////////////////

View File

@@ -192,7 +192,7 @@ public:
// Determine SMEM requirements and waive if not satisfied
//
int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage));
int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage));
cudaDeviceProp properties;
int device_idx;
@@ -208,7 +208,7 @@ public:
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
@@ -305,15 +305,15 @@ public:
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<ElementAccumulator*> (workspace.get()),
ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_D_computed.device_data(),
ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_C.device_data(),
ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
// apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
{alpha, beta}
@@ -637,7 +637,7 @@ bool TestAllConv2d(
// CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kUnity)) {
if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
continue;
@@ -645,17 +645,17 @@ bool TestAllConv2d(
}
// Fixed channels algorithm requires channel count to match access size
if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm ==
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
cutlass::conv::IteratorAlgorithm::kFixedChannels) {
if (conv_problem.C != ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::AccessType::kElements) {
if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) {
continue;
}
}
// Few channels algorithm requires channel count to match access size
if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm ==
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
cutlass::conv::IteratorAlgorithm::kFewChannels) {
if (conv_problem.C % ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::AccessType::kElements) {
if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) {
continue;
}
}
@@ -665,7 +665,7 @@ bool TestAllConv2d(
// to run strided dgrad for non-unity strides
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kStrided)) {
if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
continue;
@@ -704,14 +704,14 @@ bool TestAllConv2d(
}
// Small-channels convolution can't run here.
if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm ==
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
cutlass::conv::IteratorAlgorithm::kFixedChannels) {
return true;
}
// Small-channels convolution can't run here.
if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm ==
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
cutlass::conv::IteratorAlgorithm::kFewChannels) {
return true;
@@ -720,7 +720,7 @@ bool TestAllConv2d(
// CUTLASS DGRAD's *strided* specialization does not support split-k mode
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kStrided)) {
passed = testbed.run(

View File

@@ -257,15 +257,15 @@ public:
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<ElementAccumulator*> (workspace.get()),
ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_D_computed.device_data(),
ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_C.device_data(),
ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
// apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
{alpha, beta}
@@ -536,7 +536,7 @@ bool TestAllInterleavedConv2d(
// CUTLASS DGRAD's unity stride specialization only support stride {1, 1}
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kUnity)) {
if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
continue;

View File

@@ -253,7 +253,7 @@ public:
// Determine SMEM requirements and waive if not satisfied
//
int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage));
int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage));
cudaDeviceProp properties;
int device_idx;
@@ -269,7 +269,7 @@ public:
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
@@ -557,7 +557,7 @@ bool TestAllConv2dWithBroadcast(
// CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kUnity)) {
if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
continue;
@@ -568,7 +568,7 @@ bool TestAllConv2dWithBroadcast(
// CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2}
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kStrided)) {
if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
continue;
@@ -605,7 +605,7 @@ bool TestAllConv2dWithBroadcast(
// CUTLASS DGRAD's *strided* specialization does not support split-k mode
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kStrided)) {
passed = testbed.run(

View File

@@ -182,7 +182,7 @@ public:
// Determine SMEM requirements and waive if not satisfied
//
int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage));
int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage));
cudaDeviceProp properties;
int device_idx;
@@ -198,7 +198,7 @@ public:
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
@@ -516,7 +516,7 @@ bool TestAllConv2dWithReduction(
// CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kUnity)) {
if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
continue;
@@ -527,7 +527,7 @@ bool TestAllConv2dWithReduction(
// CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2}
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kStrided)) {
if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
continue;
@@ -564,7 +564,7 @@ bool TestAllConv2dWithReduction(
// CUTLASS DGRAD's *strided* specialization does not support split-k mode
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kStrided)) {
passed = testbed.run(

View File

@@ -184,7 +184,7 @@ public:
// Determine SMEM requirements and waive if not satisfied
//
int smem_size = int(sizeof(typename Conv3d::ImplicitGemmKernel::SharedStorage));
int smem_size = int(sizeof(typename Conv3d::UnderlyingKernel::SharedStorage));
cudaDeviceProp properties;
int device_idx;
@@ -200,7 +200,7 @@ public:
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
@@ -294,15 +294,15 @@ public:
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<ElementAccumulator*> (workspace.get()),
ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_D_computed.device_data(),
ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_C.device_data(),
ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
},
// apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
{alpha, beta}
@@ -573,9 +573,9 @@ bool TestAllConv3d(
// CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1}
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
((ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
cutlass::conv::StrideSupport::kUnity) ||
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorB::kStrideSupport ==
(ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport ==
cutlass::conv::StrideSupport::kUnity))) {
if (!((conv_problem.stride_d == 1) &&
(conv_problem.stride_h == 1) &&

View File

@@ -0,0 +1,473 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Depthwise Direct Conv testbed
*/
#pragma once
#include <fstream>
#include "../../common/cutlass_unit_test.h"
#include "cache_testbed_output.h"
#include "conv2d_problems.h"
#include "cutlass/conv/device/direct_convolution.h"
#include "cutlass/core_io.h"
#include "cutlass/cutlass.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/convolution.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
namespace test {
namespace conv {
namespace device {
/// Testbed for device-wide depthwise direct-convolution kernels.
/// Allocates and initializes host/device tensors, runs the Conv2d device
/// operator for a given problem size, and verifies the output against a
/// host or device reference implementation, optionally caching result
/// hashes to a text file to skip repeated reference computation.
template <typename Conv2d>
class TestbedDepthwiseDirectConv2d {
public:
// Element/layout types are taken from the device-level Conv2d operator.
using ElementA = typename Conv2d::ElementA;
using LayoutA = typename Conv2d::LayoutA;
using ElementB = typename Conv2d::ElementB;
using LayoutB = typename Conv2d::LayoutB;
using ElementC = typename Conv2d::ElementC;
using LayoutC = typename Conv2d::LayoutC;
using ElementAccumulator = typename Conv2d::ElementAccumulator;
using ElementCompute = typename Conv2d::ElementCompute;
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
// Convolution operator kind (fprop / dgrad / wgrad) of the device operator.
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
public:
/// Initialization
// Distributions used to fill tensors A, B, and C.
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
// Base RNG seed; derived seeds (seed*17, seed*39) are used per tensor.
uint64_t seed;
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
// Reordered copy of B, passed to the operator via its Arguments.
cutlass::HostTensor<ElementB, LayoutB> tensor_reordered_B;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
// Device-computed output and reference output for comparison.
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
// Number of problems run by this testbed instance (incremented in run()).
int tested_problem_count;
public:
/// Constructor: records fill distributions and seed; no allocation happens
/// until initialize() is called with a concrete problem size.
TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080)
: init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {}
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
void initialize_tensor(cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
// Narrow the random-value range for low-precision element types so
// accumulation error stays bounded; wider range for >=32-bit types.
int scope;
int bits = cutlass::sizeof_bits<Element>::value;
if (bits <= 8) {
scope = 2;
} else if (bits == 16) {
if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
scope = 3;
} else {
scope = 5;
}
} else {
scope = 8;
}
cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0);
} else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
} else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
} else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
} else {
// Unrecognized distribution kinds leave the tensor untouched.
}
}
/// Resizes all tensors for the given problem size, fills them on the host,
/// and copies host data to the device.
void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
initialize_tensor(tensor_A.host_view(), init_A, seed);
// tensor_reordered_B is filled with the same seed as B so they hold the
// same random content prior to any on-device reordering.
initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17);
initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_reordered_B.sync_device();
tensor_C.sync_device();
tensor_D_computed.sync_device();
tensor_D_reference.sync_device();
}
/// Returns true if the active CUDA device can supply the requested amount
/// of shared memory per block (opt-in limit); used to waive tests on
/// insufficient hardware. Throws std::runtime_error on CUDA API failure.
bool sufficient(int smem_size) const {
//
// Determine SMEM requirements and waive if not satisfied
//
cudaDeviceProp properties;
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
throw std::runtime_error("cudaGetDevice() API call failed.");
}
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
return true;
}
/// Executes one test
/// Returns true if the test passed OR was waived (unsupported config /
/// insufficient device); returns false only on an actual mismatch or
/// launch failure.
bool run(cutlass::conv::Conv2dProblemSize const &problem_size,
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
ElementCompute alpha = ElementCompute(1.5),
ElementCompute beta = ElementCompute(1)) {
// increment tested problem count run by the testbed
tested_problem_count++;
#if 0 // display conv2d problem size for debugging
std::cout << problem_size << std::endl
<< "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
<< "split_k_mode: "
<< ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)")
<< std::endl
<< std::endl;
#endif
initialize(problem_size);
// configure the operator
Conv2d conv2d_op;
typename Conv2d::Arguments conv2d_args(problem_size,
tensor_A.device_ref(),
tensor_B.device_ref(),
tensor_C.device_ref(),
tensor_D_computed.device_ref(),
{alpha, beta},
tensor_reordered_B.device_ref(),
split_k_mode);
// find workspace requirement for parallel split-k reduction
size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Waive (return true) rather than fail when the operator cannot
// implement this problem on the current device.
cutlass::Status status = conv2d_op.can_implement(problem_size);
if (status != cutlass::Status::kSuccess) {
cudaError_t error = cudaGetLastError();
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
return true;
}
status = conv2d_op.initialize(conv2d_args, workspace.get());
if (status != cutlass::Status::kSuccess) {
cudaError_t error = cudaGetLastError();
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
return true;
}
// Waive when the device lacks the shared memory this kernel requires.
if (!sufficient(conv2d_op.get_smem_size())) {
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
}
return true;
}
// run conv2d operator
status = conv2d_op();
EXPECT_TRUE(status == cutlass::Status::kSuccess);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to run." << std::endl;
return false;
}
bool passed = false;
cudaError_t result = cudaDeviceSynchronize();
EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result);
tensor_D_computed.sync_host();
//
// Reference check - support caching results
//
CachedTestKey cached_test_key =
CreateCachedConv2dTestKey<ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementCompute>(kConvolutionalOperator,
problem_size,
alpha,
beta,
tensor_A.host_view(),
tensor_B.host_view(),
tensor_C.host_view());
//
// Look for the cached key
//
bool cached_result_loaded = false;
CachedTestResult cached_test_result;
std::string conv2d_result_cache_name =
std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
CachedTestResultListing cached_results(conv2d_result_cache_name);
auto cached = cached_results.find(cached_test_key);
cached_result_loaded = cached.first;
if (cached_result_loaded) {
cached_test_result = cached.second;
}
}
// Only compute the reference when no cached hash was found.
if (!cached_result_loaded) {
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
// Device-side reference convolution.
cutlass::reference::device::Conv2d<ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator>(kConvolutionalOperator,
problem_size,
tensor_A.device_ref(),
tensor_B.device_ref(),
tensor_C.device_ref(),
tensor_D_reference.device_ref(),
alpha,
beta);
// sync host (copy device data to host) for dumping error output in case of mismatches
tensor_D_reference.sync_host();
#else
// Host-side reference convolution.
cutlass::reference::host::Conv2d<ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator>(kConvolutionalOperator,
problem_size,
tensor_A.host_ref(),
tensor_B.host_ref(),
tensor_C.host_ref(),
tensor_D_reference.host_ref(),
alpha,
beta);
#endif
// Persist the freshly computed reference hash for future runs.
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
cached_test_result.D = TensorHash(tensor_D_reference.host_view());
CachedTestResultListing cached_results(conv2d_result_cache_name);
cached_results.append(cached_test_key, cached_test_result);
cached_results.write(conv2d_result_cache_name);
}
} // if (!cached_result_loaded)
uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
// Compare by hash when caching is enabled, elementwise otherwise.
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
passed = (tensor_D_hash == cached_test_result.D);
EXPECT_EQ(tensor_D_hash, cached_test_result.D)
<< "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
}
else {
passed = cutlass::reference::host::TensorEquals(
tensor_D_computed.host_view(),
tensor_D_reference.host_view());
}
EXPECT_TRUE(passed);
// Build a human-readable problem description used in the error filename.
std::stringstream ss_problem_size_text;
ss_problem_size_text << "nhwc_"
<< problem_size.N << "x"
<< problem_size.H << "x"
<< problem_size.W << "x"
<< problem_size.C
<< "_krsc_"
<< problem_size.K << "x"
<< problem_size.R << "x"
<< problem_size.S << "x"
<< problem_size.C
<< "_padding_"
<< problem_size.pad_h << "x"
<< problem_size.pad_w
<< "_stride_"
<< problem_size.stride_h << "x"
<< problem_size.stride_w
<< "_dilation_"
<< problem_size.dilation_h << "x"
<< problem_size.dilation_w << "_"
<< (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_");
// On failure, dump inputs, reference, and computed output to a file.
if (!passed) {
std::stringstream fname;
fname << "error_Conv2d_DirectConv_device_"
<< (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
<< (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
(Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_"))
<< ss_problem_size_text.str()
<< Conv2d::ThreadblockShape::kM << "x"
<< Conv2d::ThreadblockShape::kN << "x"
<< Conv2d::ThreadblockShape::kK << "_"
<< Conv2d::WarpShape::kM << "x"
<< Conv2d::WarpShape::kN << "x"
<< Conv2d::WarpShape::kK << ".txt";
std::cout << fname.str() << std::endl;
std::ofstream results(fname.str());
results << problem_size << std::endl;
results
<< "\nA:\n" << tensor_A.host_view() << "\n"
<< "\nB:\n" << tensor_B.host_view() << "\n"
<< "\nC:\n" << tensor_C.host_view() << "\n";
results << "\nD reference (hash: " << cached_test_result.D << ")\n";
// The full reference tensor is only available when it was recomputed
// this run (not when only a cached hash was loaded).
if (!cached_result_loaded) {
results
<< tensor_D_reference.host_view() << "\n";
}
results
<< "\nD computed (hash: " << tensor_D_hash << ")\n"
<< tensor_D_computed.host_view() << "\n";
}
return passed;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs every problem size in \p problem_sizes through the depthwise
/// direct-conv testbed, once in cross-correlation mode and once in
/// convolution mode (split-k-mode=kSerial, split-k-slice=1, alpha=1.0,
/// beta=0.0). Returns false on the first failing configuration.
template <typename DirectConv>
bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) {
  //
  // Testbed object
  //
  TestbedDepthwiseDirectConv2d<DirectConv> testbed;

  // Sweep the supplied conv2d problem sizes.
  for (auto conv_problem : problem_sizes) {
    // Cross-correlation mode.
    if (!testbed.run(conv_problem, cutlass::conv::SplitKMode::kSerial)) {
      return false;
    }

    // Convolution mode.
    if (!testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
                     cutlass::conv::SplitKMode::kSerial)) {
      return false;
    }
  }

  return true;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace conv
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,426 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide Depthwise Direct Conv interface
*/
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_depthwise_fprop.h"
#include "cutlass/conv/device/direct_convolution.h"
#include "conv2d_testbed.h"
#include "depthwise_conv2d_direct_conv_testbed.h"
/// Builds depthwise fprop problem sizes with a 3x3 filter, sweeping the
/// channel count from 16 through 512 in steps of 16. Each channel count
/// yields two cases: an 8x8 unit-stride activation and a 16x16 activation
/// with stride 2 and dilation 2. All problems use split_k_slices=16 and
/// groups == channels (fully depthwise).
std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes_filter3x3() {
  std::vector<cutlass::conv::Conv2dProblemSize> result;

  for (int c = 16; c <= 512; c += 16) {
    // Case 1: 8x8 activation, unit stride and dilation.
    result.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 8, 8, c},                            // input size (NHWC)
        {c, 3, 3, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {1, 1},                                  // stride (stride_h, stride_w)
        {1, 1},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // Convolution mode
        16,                                      // split_k_slices
        c                                        // groups
    ));

    // Case 2: 16x16 activation, stride 2 and dilation 2.
    result.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 16, 16, c},                          // input size (NHWC)
        {c, 3, 3, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {2, 2},                                  // stride (stride_h, stride_w)
        {2, 2},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // Convolution mode
        16,                                      // split_k_slices
        c                                        // groups
    ));
  }

  return result;
}
/// Builds depthwise fprop problem sizes with a 5x5 filter, sweeping channel
/// counts from 16 up to (not including) 256 in steps of 16. Each channel
/// count yields three cases: 16x16 unit-stride, 112x112 unit-stride, and
/// 112x112 with stride 2 / dilation 2. All problems use split_k_slices=16
/// and groups == channels (fully depthwise).
std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes_filter5x5() {
  std::vector<cutlass::conv::Conv2dProblemSize> result;

  for (int c = 16; c < 256; c += 16) {
    // Case 1: 16x16 activation, unit stride and dilation.
    result.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 16, 16, c},                          // input size (NHWC)
        {c, 5, 5, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {1, 1},                                  // stride (stride_h, stride_w)
        {1, 1},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // Convolution mode
        16,                                      // split_k_slices
        c                                        // groups
    ));

    // Case 2: 112x112 activation, unit stride and dilation.
    result.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 112, 112, c},                        // input size (NHWC)
        {c, 5, 5, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {1, 1},                                  // stride (stride_h, stride_w)
        {1, 1},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // Convolution mode
        16,                                      // split_k_slices
        c                                        // groups
    ));

    // Case 3: 112x112 activation, stride 2 and dilation 2.
    result.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 112, 112, c},                        // input size (NHWC)
        {c, 5, 5, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {2, 2},                                  // stride (stride_h, stride_w)
        {2, 2},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // Convolution mode
        16,                                      // split_k_slices
        c                                        // groups
    ));
  }

  return result;
}
/// Builds depthwise fprop problem sizes with a non-square 5x37 filter on a
/// 128x128 activation, sweeping channel counts from 16 up to (not including)
/// 256 in steps of 16. All problems use split_k_slices=108 and
/// groups == channels (fully depthwise).
std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes_filter5x37() {
  std::vector<cutlass::conv::Conv2dProblemSize> result;

  for (int c = 16; c < 256; c += 16) {
    result.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 128, 128, c},                        // input size (NHWC)
        {c, 5, 37, 1},                           // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {1, 1},                                  // stride (stride_h, stride_w)
        {1, 1},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // Convolution mode
        108,                                     // split_k_slices
        c                                        // groups
    ));
  }

  return result;
}
////////////////////////////////////////////////////////////////////////////////
// Depthwise direct-conv fprop, optimized iterator algorithm, 3x3 filter:
// 64x32 threadblock tile, 4 pipeline stages, 8x32 warp tile, fp16 throughout.
TEST(
    SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16,
    64x32_4_8x32_3x3) {
  // All operands, the accumulator, and the epilogue compute type are fp16; all
  // tensors use the NHWC layout.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementComputeEpilogue = cutlass::half_t;
  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::TensorNHWC;

  // Use regular SIMT cores (not tensor cores) on the GPU SM.
  using MMAOp = cutlass::arch::OpClassSimt;

  // CUDA SM architecture number.
  using SmArch = cutlass::arch::Sm60;

  // Number of groups (channels) a thread block computes.
  constexpr int groups_per_cta = 32;

  // Output tile <N, P, Q, C> a thread block computes.
  using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;

  // Filter shape <R, S>.
  using FilterShape = cutlass::MatrixShape<3, 3>;

  // Threadblock tile shape.
  using ThreadblockShape =
      cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;

  // Tile size a warp will compute.
  using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>;

  // Size of the MMA op.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // How threadblocks are scheduled on the GPU.
  using SwizzleThreadBlock =
      cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
          1,
          ThreadBlockOutputShape::kN,
          ThreadBlockOutputShape::kH,
          ThreadBlockOutputShape::kW>;

  // Number of pipeline stages.
  constexpr int NumStages = 4;

  // Iterator algorithm selected: Analytic or Optimized.
  static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
      cutlass::conv::IteratorAlgorithm::kOptimized;

  constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Epilogue: default linear combination.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,               // Data type of the output matrix.
      kEpilogueElementsPerAccess,  // Number of elements per vectorized memory
                                   // access; this becomes the vector width of
                                   // math instructions in the epilogue too.
      ElementAccumulator,          // Data type of the accumulator.
      ElementComputeEpilogue,      // Data type of alpha/beta in the linear combination.
      cutlass::epilogue::thread::ScaleType::Default>;

  using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      ThreadBlockOutputShape,
      FilterShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages,
      cutlass::arch::OpMultiplyAdd,
      IteratorAlgorithm,
      cutlass::conv::StrideSupport::kStrided>::Kernel;

  using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

  // Run all unit test sizes with the device-level Conv2d instance.
  EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConv>(
      DepthwiseFpropProblemSizes_filter3x3()));
}
////////////////////////////////////////////////////////////////////////////////
// Depthwise direct-conv fprop, optimized iterator algorithm, 5x5 filter:
// 64x64 threadblock tile, 3 pipeline stages, 16x64 warp tile, fp16 throughout.
TEST(
    SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16,
    64x64_3_16x64_5x5) {
  // All operands, the accumulator, and the epilogue compute type are fp16; all
  // tensors use the NHWC layout.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementComputeEpilogue = cutlass::half_t;
  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::TensorNHWC;

  // Use regular SIMT cores (not tensor cores) on the GPU SM.
  using MMAOp = cutlass::arch::OpClassSimt;

  // CUDA SM architecture number.
  using SmArch = cutlass::arch::Sm60;

  // Number of groups (channels) a thread block computes.
  constexpr int groups_per_cta = 64;

  // Output tile <N, P, Q, C> a thread block computes.
  using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;

  // Filter shape <R, S>.
  using FilterShape = cutlass::MatrixShape<5, 5>;

  // Threadblock tile shape.
  using ThreadblockShape =
      cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;

  // Tile size a warp will compute.
  using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>;

  // Size of the MMA op.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // How threadblocks are scheduled on the GPU.
  using SwizzleThreadBlock =
      cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
          1,
          ThreadBlockOutputShape::kN,
          ThreadBlockOutputShape::kH,
          ThreadBlockOutputShape::kW>;

  // Number of pipeline stages.
  constexpr int NumStages = 3;

  // Iterator algorithm selected: Analytic or Optimized.
  static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
      cutlass::conv::IteratorAlgorithm::kOptimized;

  constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Epilogue: default linear combination.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,               // Data type of the output matrix.
      kEpilogueElementsPerAccess,  // Number of elements per vectorized memory
                                   // access; this becomes the vector width of
                                   // math instructions in the epilogue too.
      ElementAccumulator,          // Data type of the accumulator.
      ElementComputeEpilogue,      // Data type of alpha/beta in the linear combination.
      cutlass::epilogue::thread::ScaleType::Default>;

  using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      ThreadBlockOutputShape,
      FilterShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages,
      cutlass::arch::OpMultiplyAdd,
      IteratorAlgorithm,
      cutlass::conv::StrideSupport::kStrided>::Kernel;

  using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

  // Run all unit test sizes with the device-level Conv2d instance.
  EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConv>(
      DepthwiseFpropProblemSizes_filter5x5()));
}
////////////////////////////////////////////////////////////////////////////////
// Depthwise direct-conv fprop, optimized iterator algorithm, asymmetric 5x37
// filter: 64x32 threadblock tile, 2 pipeline stages, 16x32 warp tile, fp16.
TEST(
    SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16,
    64x32_3_16x32_5x37) {
  // All operands, the accumulator, and the epilogue compute type are fp16; all
  // tensors use the NHWC layout.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementComputeEpilogue = cutlass::half_t;
  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::TensorNHWC;

  // Use regular SIMT cores (not tensor cores) on the GPU SM.
  using MMAOp = cutlass::arch::OpClassSimt;

  // CUDA SM architecture number.
  using SmArch = cutlass::arch::Sm60;

  // Number of groups (channels) a thread block computes.
  constexpr int groups_per_cta = 32;

  // Output tile <N, P, Q, C> a thread block computes.
  using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;

  // Filter shape <R, S>.
  using FilterShape = cutlass::MatrixShape<5, 37>;

  // Threadblock tile shape.
  using ThreadblockShape =
      cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;

  // Tile size a warp will compute.
  using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>;

  // Size of the MMA op.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // How threadblocks are scheduled on the GPU.
  using SwizzleThreadBlock =
      cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
          1,
          ThreadBlockOutputShape::kN,
          ThreadBlockOutputShape::kH,
          ThreadBlockOutputShape::kW>;

  // Number of pipeline stages.
  constexpr int NumStages = 2;

  // Iterator algorithm selected: Analytic or Optimized.
  static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
      cutlass::conv::IteratorAlgorithm::kOptimized;

  constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Epilogue: default linear combination.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,               // Data type of the output matrix.
      kEpilogueElementsPerAccess,  // Number of elements per vectorized memory
                                   // access; this becomes the vector width of
                                   // math instructions in the epilogue too.
      ElementAccumulator,          // Data type of the accumulator.
      ElementComputeEpilogue,      // Data type of alpha/beta in the linear combination.
      cutlass::epilogue::thread::ScaleType::Default>;

  using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      ThreadBlockOutputShape,
      FilterShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages,
      cutlass::arch::OpMultiplyAdd,
      IteratorAlgorithm,
      cutlass::conv::StrideSupport::kStrided>::Kernel;

  using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

  // Run all unit test sizes with the device-level Conv2d instance.
  EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConv>(
      DepthwiseFpropProblemSizes_filter5x37()));
}

View File

@@ -0,0 +1,522 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide Depthwise Direct Conv interface
*/
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_depthwise_fprop.h"
#include "cutlass/conv/device/direct_convolution.h"
#include "conv2d_testbed.h"
#include "depthwise_conv2d_direct_conv_testbed.h"
// Depthwise fprop problem sizes for a 3x3 filter with unit stride and
// dilation on an 8x8 activation; channel counts sweep 16..512 in steps of 16.
std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1() {
  std::vector<cutlass::conv::Conv2dProblemSize> problems;

  for (int c = 16; c <= 512; c += 16) {
    problems.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 8, 8, c},                            // input size (NHWC)
        {c, 3, 3, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {1, 1},                                  // stride (stride_h, stride_w)
        {1, 1},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // convolution mode
        16,                                      // split_k_slices
        c                                        // groups (depthwise: one group per channel)
    ));
  }
  return problems;
}
// Depthwise fprop problem sizes for a 3x3 filter with 2x2 stride and 2x2
// dilation on a 16x16 activation; channel counts sweep 16..512 in steps of 16.
std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2() {
  std::vector<cutlass::conv::Conv2dProblemSize> problems;

  for (int c = 16; c <= 512; c += 16) {
    problems.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 16, 16, c},                          // input size (NHWC)
        {c, 3, 3, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {2, 2},                                  // stride (stride_h, stride_w)
        {2, 2},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // convolution mode
        16,                                      // split_k_slices
        c                                        // groups (depthwise: one group per channel)
    ));
  }
  return problems;
}
// Depthwise fprop problem sizes for a 5x5 filter with unit stride and
// dilation on a 16x16 activation; channel counts sweep 16..240 in steps of 16.
std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1() {
  std::vector<cutlass::conv::Conv2dProblemSize> problems;

  for (int c = 16; c < 256; c += 16) {
    problems.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 16, 16, c},                          // input size (NHWC)
        {c, 5, 5, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {1, 1},                                  // stride (stride_h, stride_w)
        {1, 1},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // convolution mode
        16,                                      // split_k_slices
        c                                        // groups (depthwise: one group per channel)
    ));
  }
  return problems;
}
// Depthwise fprop problem sizes for a 5x5 filter with 2x2 stride and 2x2
// dilation on a 112x112 activation; channel counts sweep 16..240 in steps of 16.
std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2() {
  std::vector<cutlass::conv::Conv2dProblemSize> problems;

  for (int c = 16; c < 256; c += 16) {
    problems.push_back(cutlass::conv::Conv2dProblemSize(
        {1, 112, 112, c},                        // input size (NHWC)
        {c, 5, 5, 1},                            // filter size (KRSC)
        {1, 1, 1, 1},                            // padding (pad_h, _, pad_w, _)
        {2, 2},                                  // stride (stride_h, stride_w)
        {2, 2},                                  // dilation (dilation_h, dilation_w)
        cutlass::conv::Mode::kCrossCorrelation,  // convolution mode
        16,                                      // split_k_slices
        c                                        // groups (depthwise: one group per channel)
    ));
  }
  return problems;
}
////////////////////////////////////////////////////////////////////////////////
// Depthwise direct-conv fprop with compile-time fixed 1x1 stride/dilation,
// 3x3 filter: 64x32 threadblock tile, 4 pipeline stages, 8x32 warp tile, fp16.
TEST(
    SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16,
    64x32_4_8x32_Filter3x3_Stride1x1_Dilation1x1) {
  // All operands, the accumulator, and the epilogue compute type are fp16; all
  // tensors use the NHWC layout.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementComputeEpilogue = cutlass::half_t;
  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::TensorNHWC;

  // Use regular SIMT cores (not tensor cores) on the GPU SM.
  using MMAOp = cutlass::arch::OpClassSimt;

  // CUDA SM architecture number.
  using SmArch = cutlass::arch::Sm60;

  // Number of groups (channels) a thread block computes.
  constexpr int groups_per_cta = 32;

  // Output tile <N, P, Q, C> a thread block computes.
  using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;

  // Filter shape <R, S>.
  using FilterShape = cutlass::MatrixShape<3, 3>;

  // Threadblock tile shape.
  using ThreadblockShape =
      cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;

  // Tile size a warp will compute.
  using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>;

  // Size of the MMA op.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // How threadblocks are scheduled on the GPU.
  using SwizzleThreadBlock =
      cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
          1,
          ThreadBlockOutputShape::kN,
          ThreadBlockOutputShape::kH,
          ThreadBlockOutputShape::kW>;

  // Number of pipeline stages.
  constexpr int NumStages = 4;

  // Iterator algorithm: stride and dilation are fixed at compile time.
  static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
      cutlass::conv::IteratorAlgorithm::kFixedStrideDilation;

  // Compile-time stride <stride_h, stride_w> and dilation <dilation_h, dilation_w>;
  // these must match the runtime problem sizes passed below.
  using StrideShape = cutlass::MatrixShape<1, 1>;
  using DilationShape = cutlass::MatrixShape<1, 1>;

  constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Epilogue: default linear combination.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,               // Data type of the output matrix.
      kEpilogueElementsPerAccess,  // Number of elements per vectorized memory
                                   // access; this becomes the vector width of
                                   // math instructions in the epilogue too.
      ElementAccumulator,          // Data type of the accumulator.
      ElementComputeEpilogue,      // Data type of alpha/beta in the linear combination.
      cutlass::epilogue::thread::ScaleType::Default>;

  using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      ThreadBlockOutputShape,
      FilterShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages,
      cutlass::arch::OpMultiplyAdd,
      IteratorAlgorithm,
      cutlass::conv::StrideSupport::kFixed,
      StrideShape,
      DilationShape>::Kernel;

  using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

  // Run all unit test sizes with the device-level Conv2d instance.
  EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConv>(
      DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1()));
}
////////////////////////////////////////////////////////////////////////////////
// Depthwise direct-conv fprop with compile-time fixed 2x2 stride/dilation,
// 3x3 filter: 64x32 threadblock tile, 4 pipeline stages, 8x32 warp tile, fp16.
TEST(
    SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16,
    64x32_4_8x32_Filter3x3_Stride2x2_Dilation2x2) {
  // All operands, the accumulator, and the epilogue compute type are fp16; all
  // tensors use the NHWC layout.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementComputeEpilogue = cutlass::half_t;
  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::TensorNHWC;

  // Use regular SIMT cores (not tensor cores) on the GPU SM.
  using MMAOp = cutlass::arch::OpClassSimt;

  // CUDA SM architecture number.
  using SmArch = cutlass::arch::Sm60;

  // Number of groups (channels) a thread block computes.
  constexpr int groups_per_cta = 32;

  // Output tile <N, P, Q, C> a thread block computes.
  using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;

  // Filter shape <R, S>.
  using FilterShape = cutlass::MatrixShape<3, 3>;

  // Threadblock tile shape.
  using ThreadblockShape =
      cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;

  // Tile size a warp will compute.
  using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>;

  // Size of the MMA op.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // How threadblocks are scheduled on the GPU.
  using SwizzleThreadBlock =
      cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
          1,
          ThreadBlockOutputShape::kN,
          ThreadBlockOutputShape::kH,
          ThreadBlockOutputShape::kW>;

  // Number of pipeline stages.
  constexpr int NumStages = 4;

  // Iterator algorithm: stride and dilation are fixed at compile time.
  static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
      cutlass::conv::IteratorAlgorithm::kFixedStrideDilation;

  // Compile-time stride <stride_h, stride_w> and dilation <dilation_h, dilation_w>;
  // these must match the runtime problem sizes passed below.
  using StrideShape = cutlass::MatrixShape<2, 2>;
  using DilationShape = cutlass::MatrixShape<2, 2>;

  constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Epilogue: default linear combination.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,               // Data type of the output matrix.
      kEpilogueElementsPerAccess,  // Number of elements per vectorized memory
                                   // access; this becomes the vector width of
                                   // math instructions in the epilogue too.
      ElementAccumulator,          // Data type of the accumulator.
      ElementComputeEpilogue,      // Data type of alpha/beta in the linear combination.
      cutlass::epilogue::thread::ScaleType::Default>;

  using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      ThreadBlockOutputShape,
      FilterShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages,
      cutlass::arch::OpMultiplyAdd,
      IteratorAlgorithm,
      cutlass::conv::StrideSupport::kFixed,
      StrideShape,
      DilationShape>::Kernel;

  using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

  // Run all unit test sizes with the device-level Conv2d instance.
  EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConv>(
      DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2()));
}
////////////////////////////////////////////////////////////////////////////////
// Depthwise direct-conv fprop with compile-time fixed 1x1 stride/dilation,
// 5x5 filter: 64x64 threadblock tile, 3 pipeline stages, 16x64 warp tile, fp16.
TEST(
    SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16,
    64x64_3_16x64_Filter5x5_Stride1x1_Dilation1x1) {
  // All operands, the accumulator, and the epilogue compute type are fp16; all
  // tensors use the NHWC layout.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementComputeEpilogue = cutlass::half_t;
  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::TensorNHWC;

  // Use regular SIMT cores (not tensor cores) on the GPU SM.
  using MMAOp = cutlass::arch::OpClassSimt;

  // CUDA SM architecture number.
  using SmArch = cutlass::arch::Sm60;

  // Number of groups (channels) a thread block computes.
  constexpr int groups_per_cta = 64;

  // Output tile <N, P, Q, C> a thread block computes.
  using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;

  // Filter shape <R, S>.
  using FilterShape = cutlass::MatrixShape<5, 5>;

  // Threadblock tile shape.
  using ThreadblockShape =
      cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;

  // Tile size a warp will compute.
  using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>;

  // Size of the MMA op.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // How threadblocks are scheduled on the GPU.
  using SwizzleThreadBlock =
      cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
          1,
          ThreadBlockOutputShape::kN,
          ThreadBlockOutputShape::kH,
          ThreadBlockOutputShape::kW>;

  // Number of pipeline stages.
  constexpr int NumStages = 3;

  // Iterator algorithm: stride and dilation are fixed at compile time.
  static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
      cutlass::conv::IteratorAlgorithm::kFixedStrideDilation;

  // Compile-time stride <stride_h, stride_w> and dilation <dilation_h, dilation_w>;
  // these must match the runtime problem sizes passed below.
  using StrideShape = cutlass::MatrixShape<1, 1>;
  using DilationShape = cutlass::MatrixShape<1, 1>;

  constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Epilogue: default linear combination.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,               // Data type of the output matrix.
      kEpilogueElementsPerAccess,  // Number of elements per vectorized memory
                                   // access; this becomes the vector width of
                                   // math instructions in the epilogue too.
      ElementAccumulator,          // Data type of the accumulator.
      ElementComputeEpilogue,      // Data type of alpha/beta in the linear combination.
      cutlass::epilogue::thread::ScaleType::Default>;

  using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      ThreadBlockOutputShape,
      FilterShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages,
      cutlass::arch::OpMultiplyAdd,
      IteratorAlgorithm,
      cutlass::conv::StrideSupport::kFixed,
      StrideShape,
      DilationShape>::Kernel;

  using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

  // Run all unit test sizes with the device-level Conv2d instance.
  EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConv>(
      DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1()));
}
////////////////////////////////////////////////////////////////////////////////
// Depthwise direct-conv fprop with compile-time fixed 2x2 stride/dilation,
// 5x5 filter: 3 pipeline stages, fp16 throughout.
// NOTE(review): the test name says 64x64_3_16x64, but groups_per_cta below is
// 32, giving a 64x32 threadblock / 16x32 warp tile — confirm which is intended.
TEST(
    SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16,
    64x64_3_16x64_Filter5x5_Stride2x2_Dilation2x2) {
  // All operands, the accumulator, and the epilogue compute type are fp16; all
  // tensors use the NHWC layout.
  using ElementInputA = cutlass::half_t;
  using ElementInputB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementComputeEpilogue = cutlass::half_t;
  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::TensorNHWC;

  // Use regular SIMT cores (not tensor cores) on the GPU SM.
  using MMAOp = cutlass::arch::OpClassSimt;

  // CUDA SM architecture number.
  using SmArch = cutlass::arch::Sm60;

  // Number of groups (channels) a thread block computes.
  constexpr int groups_per_cta = 32;

  // Output tile <N, P, Q, C> a thread block computes.
  using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;

  // Filter shape <R, S>.
  using FilterShape = cutlass::MatrixShape<5, 5>;

  // Threadblock tile shape.
  using ThreadblockShape =
      cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;

  // Tile size a warp will compute.
  using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>;

  // Size of the MMA op.
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // How threadblocks are scheduled on the GPU.
  using SwizzleThreadBlock =
      cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
          1,
          ThreadBlockOutputShape::kN,
          ThreadBlockOutputShape::kH,
          ThreadBlockOutputShape::kW>;

  // Number of pipeline stages.
  constexpr int NumStages = 3;

  // Iterator algorithm: stride and dilation are fixed at compile time.
  static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
      cutlass::conv::IteratorAlgorithm::kFixedStrideDilation;

  // Compile-time stride <stride_h, stride_w> and dilation <dilation_h, dilation_w>;
  // these must match the runtime problem sizes passed below.
  using StrideShape = cutlass::MatrixShape<2, 2>;
  using DilationShape = cutlass::MatrixShape<2, 2>;

  constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // Epilogue: default linear combination.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,               // Data type of the output matrix.
      kEpilogueElementsPerAccess,  // Number of elements per vectorized memory
                                   // access; this becomes the vector width of
                                   // math instructions in the epilogue too.
      ElementAccumulator,          // Data type of the accumulator.
      ElementComputeEpilogue,      // Data type of alpha/beta in the linear combination.
      cutlass::epilogue::thread::ScaleType::Default>;

  using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      ThreadBlockOutputShape,
      FilterShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages,
      cutlass::arch::OpMultiplyAdd,
      IteratorAlgorithm,
      cutlass::conv::StrideSupport::kFixed,
      StrideShape,
      DilationShape>::Kernel;

  using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

  // Run all unit test sizes with the device-level Conv2d instance.
  EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConv>(
      DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2()));
}
}

View File

@@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide Implicit GEMM interface
\brief Tests for Depthwise Direct Conv interface
*/
#include "../../common/cutlass_unit_test.h"

View File

@@ -241,6 +241,155 @@ TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhw
////////////////////////////////////////////////////////////////////////////////
// Group conv fprop via implicit GEMM, single group per CTA:
// 128x128x64 threadblock, 64x64x64 warp, 3 stages, fp16 tensor-op with fp32
// accumulation, optimized iterator algorithm.
TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32,
  SingleGroupPerCTA_128x128_64x3_64x64x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA           = cutlass::half_t;
  using ElementB           = cutlass::half_t;
  using ElementC           = cutlass::half_t;
  using ElementAccumulator = float;
  using ElementCompute     = float;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  /// Device-level Conv2d instance
  using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::GroupMode::kSingleGroup,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dGroupFpropKernel>;

  /// Run group conv unit test sizes with device-level Conv2d instance.
  /// Problem sizes are generated to fit the threadblock tile (N, K dims) and
  /// the 128-bit vectorized access width of ElementA.
  test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes(
    ThreadblockShape::kN, ThreadblockShape::kK,
    128/cutlass::sizeof_bits<ElementA>::value
  );
  EXPECT_TRUE(test::conv::device::TestSpecificConv2d<Conv2dGroupFprop>(problem_sizes.default_single_group_sizes));
}
////////////////////////////////////////////////////////////////////////////////
// Optimized multistage singleGroup kernel
// Group conv fprop via implicit GEMM, single group per CTA:
// 64x64x64 threadblock, 32x32x64 warp, 3 stages (multistage), fp16 tensor-op
// with fp32 accumulation, optimized iterator algorithm.
TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32,
  SingleGroupPerCTA_64x64_64x3_32x32x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA           = cutlass::half_t;
  using ElementB           = cutlass::half_t;
  using ElementC           = cutlass::half_t;
  using ElementAccumulator = float;
  using ElementCompute     = float;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  /// Device-level Conv2d instance
  using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::GroupMode::kSingleGroup,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dGroupFpropKernel>;

  /// Run group conv unit test sizes with device-level Conv2d instance.
  /// Problem sizes are generated to fit the threadblock tile (N, K dims) and
  /// the 128-bit vectorized access width of ElementA.
  test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes(
    ThreadblockShape::kN, ThreadblockShape::kK,
    128/cutlass::sizeof_bits<ElementA>::value
  );
  EXPECT_TRUE(test::conv::device::TestSpecificConv2d<Conv2dGroupFprop>(problem_sizes.default_single_group_sizes));
}
////////////////////////////////////////////////////////////////////////////////
// Optimized 2 stage singleGroup kernel
// Two-stage variant of the 64x64x64 single-group Fprop test. Note the output
// element type is float here (not half), so the epilogue vector width
// 128 / sizeof_bits<ElementC> evaluates to 4 instead of 8.
TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32,
  SingleGroupPerCTA_64x64_64x2_32x32x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  // Tile shapes of the implicit GEMM: threadblock / warp / tensor-core instruction.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  /// Device-level Conv2d instance
  using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,  // epilogue vector width for 128-bit accesses
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,  // presumably the mainloop stage count — confirm against DefaultConv2dGroupFprop
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::GroupMode::kSingleGroup,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dGroupFpropKernel>;

  /// Run group conv unit test sizes with device-level Conv2d instance
  // Problem sizes are generated to match the CTA tile (kN, kK) and the
  // 128-bit alignment of operand A, in elements.
  test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes(
    ThreadblockShape::kN, ThreadblockShape::kK,
    128/cutlass::sizeof_bits<ElementA>::value
  );
  EXPECT_TRUE(test::conv::device::TestSpecificConv2d<Conv2dGroupFprop>(problem_sizes.default_single_group_sizes));
}
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
////////////////////////////////////////////////////////////////////////////////

View File

@@ -31,6 +31,7 @@ cutlass_test_unit_add_executable(
array.cu
half.cu
bfloat16.cu
float8.cu
tfloat32.cu
complex.cu
quaternion.cu

103
test/unit/core/float8.cu Normal file
View File

@@ -0,0 +1,103 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for basic float8 functionality
*/
#include "../common/cutlass_unit_test.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(float_e4m3_t, host_conversion) {
  // Round-trip small integers through float_e4m3_t, converting from both an
  // int source and a float source, and expect exact recovery.
  for (int value = -8; value < 8; ++value) {
    float as_float = static_cast<float>(value);
    cutlass::float_e4m3_t from_int = static_cast<cutlass::float_e4m3_t>(value);
    cutlass::float_e4m3_t from_float = static_cast<cutlass::float_e4m3_t>(as_float);
    EXPECT_TRUE(static_cast<int>(from_int) == value);
    EXPECT_TRUE(static_cast<float>(from_float) == as_float);
  }

  // Default construction zero-initializes the underlying proxy storage.
  EXPECT_TRUE(cutlass::float_e4m3_t() == 0.0_fe4m3);

  // User-defined literal agrees with explicit construction and converts back.
  EXPECT_TRUE(cutlass::float_e4m3_t(7) == 7_fe4m3);
  EXPECT_TRUE(7 == static_cast<int>(7_fe4m3));
}
TEST(float_e5m2_t, host_conversion) {
  // Round-trip small integers through float_e5m2_t, converting from both an
  // int source and a float source, and expect exact recovery.
  for (int value = -8; value < 8; ++value) {
    float as_float = static_cast<float>(value);
    cutlass::float_e5m2_t from_int = static_cast<cutlass::float_e5m2_t>(value);
    cutlass::float_e5m2_t from_float = static_cast<cutlass::float_e5m2_t>(as_float);
    EXPECT_TRUE(static_cast<int>(from_int) == value);
    EXPECT_TRUE(static_cast<float>(from_float) == as_float);
  }

  // Default construction zero-initializes the underlying proxy storage.
  EXPECT_TRUE(cutlass::float_e5m2_t() == 0.0_fe5m2);

  // User-defined literal agrees with explicit construction and converts back.
  EXPECT_TRUE(cutlass::float_e5m2_t(7) == 7_fe5m2);
  EXPECT_TRUE(7 == static_cast<int>(7_fe5m2));
}
TEST(float_e4m3_t, host_arithmetic) {
  // Addition over all pairs of small integers: the sum of any two values in
  // [-4, 4) is expected to be exact in e4m3.
  for (int lhs = -4; lhs < 4; ++lhs) {
    for (int rhs = -4; rhs < 4; ++rhs) {
      cutlass::float_e4m3_t a = static_cast<cutlass::float_e4m3_t>(lhs);
      cutlass::float_e4m3_t b = static_cast<cutlass::float_e4m3_t>(rhs);
      EXPECT_TRUE(static_cast<int>(a + b) == lhs + rhs);
    }
  }
}
TEST(float_e5m2_t, host_arithmetic) {
  // Addition over all pairs of small integers: the sum of any two values in
  // [-4, 4) is expected to be exact in e5m2.
  for (int lhs = -4; lhs < 4; ++lhs) {
    for (int rhs = -4; rhs < 4; ++rhs) {
      cutlass::float_e5m2_t a = static_cast<cutlass::float_e5m2_t>(lhs);
      cutlass::float_e5m2_t b = static_cast<cutlass::float_e5m2_t>(rhs);
      EXPECT_TRUE(static_cast<int>(a + b) == lhs + rhs);
    }
  }
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -47,10 +47,10 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Conversion template
/// Simple conversion function
template <typename Destination, typename Source, int Count>
__global__ void convert(
cutlass::Array<Destination, Count> *destination,
cutlass::Array<Destination, Count> *destination,
cutlass::Array<Source, Count> const *source) {
cutlass::NumericArrayConverter<Destination, Source, Count> convert;
@@ -60,47 +60,9 @@ __global__ void convert(
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace core
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(NumericConversion, f32_to_f16_rn) {
int const kN = 1;
using Source = float;
using Destination = cutlass::half_t;
dim3 grid(1, 1);
dim3 block(1, 1);
cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<float, cutlass::layout::RowMajor> source({1, kN});
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = float(i);
}
source.sync_device();
test::core::kernel::convert<Destination, Source, 1><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, 1> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, 1> const *>(source.device_data())
);
destination.sync_host();
for (int i = 0; i < kN; ++i) {
EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]);
}
}
TEST(NumericConversion, f32x8_to_f16x8_rn) {
int const kN = 8;
using Source = float;
using Destination = cutlass::half_t;
template <typename Destination, typename Source, int Count>
void run_test() {
const int kN = Count;
dim3 grid(1, 1);
dim3 block(1, 1);
@@ -109,12 +71,12 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) {
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = float(i);
source.host_data()[i] = Source(i % 4);
}
source.sync_device();
test::core::kernel::convert<Destination, Source, kN><<< grid, block >>>(
convert<Destination, Source, kN><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
);
@@ -122,70 +84,247 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) {
destination.sync_host();
for (int i = 0; i < kN; ++i) {
EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]);
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
}
}
} // namespace kernel
} // namespace core
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(NumericConversion, f32_to_f16_rn) {
  // Scalar float -> half device conversion; run_test<Destination, Source, Count>.
  test::core::kernel::run_test<cutlass::half_t, float, 1>();
}
TEST(NumericConversion, f32x8_to_f16x8_rn) {
  // 8-element float -> half device conversion.
  test::core::kernel::run_test<cutlass::half_t, float, 8>();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(NumericConversion, f16_to_f32_rn) {
TEST(NumericConversion, f16_to_f32_rn) {
int const kN = 1;
using Source = cutlass::half_t;
using Destination = float;
dim3 grid(1, 1);
dim3 block(1, 1);
cutlass::HostTensor<float, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> source({1, kN});
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = Source(i);
}
source.sync_device();
test::core::kernel::convert<Destination, Source, kN><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
);
destination.sync_host();
for (int i = 0; i < kN; ++i) {
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
}
test::core::kernel::run_test<Destination, Source, kN>();
}
TEST(NumericConversion, f16x8_to_f32x8_rn) {
  // 8-element half -> float device conversion.
  test::core::kernel::run_test<float, cutlass::half_t, 8>();
}
dim3 grid(1, 1);
dim3 block(1, 1);
/////////////////////////////////////////////////////////////////////////////////////////////////
cutlass::HostTensor<float, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> source({1, kN});
TEST(NumericConversion, f32_to_fe4m3_rn) {
  // Scalar float -> e4m3 device conversion.
  test::core::kernel::run_test<cutlass::float_e4m3_t, float, 1>();
}
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = float(i);
}
// 27-element float -> e4m3 device conversion.
// Count = 27 deliberately does not divide the converter's vector width,
// matching the other *_array tests in this file.
TEST(NumericConversion, f32_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = float;
  using Destination = cutlass::float_e4m3_t;

  // Fix: removed a stray `source.sync_device();` statement that referenced an
  // undeclared variable — `run_test` owns allocation, sync, and verification.
  test::core::kernel::run_test<Destination, Source, kN>();
}
test::core::kernel::convert<Destination, Source, kN><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
);
TEST(NumericConversion, f32_to_fe5m2_rn) {
  // Scalar float -> e5m2 device conversion.
  test::core::kernel::run_test<cutlass::float_e5m2_t, float, 1>();
}
destination.sync_host();
// 27-element float -> e5m2 device conversion.
// Count = 27 deliberately does not divide the converter's vector width,
// matching the other *_array tests in this file.
TEST(NumericConversion, f32_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = float;
  using Destination = cutlass::float_e5m2_t;

  // Fix: removed a stray verification loop that referenced undeclared
  // `source`/`destination` tensors — `run_test` performs the check itself.
  test::core::kernel::run_test<Destination, Source, kN>();
}
TEST(NumericConversion, f16_to_fe4m3_rn) {
  // Scalar half -> e4m3 device conversion.
  test::core::kernel::run_test<cutlass::float_e4m3_t, cutlass::half_t, 1>();
}
TEST(NumericConversion, f16_to_fe4m3_rn_array) {
  // 27-element half -> e4m3 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::float_e4m3_t, cutlass::half_t, 27>();
}
TEST(NumericConversion, f16_to_fe5m2_rn) {
  // Scalar half -> e5m2 device conversion.
  test::core::kernel::run_test<cutlass::float_e5m2_t, cutlass::half_t, 1>();
}
TEST(NumericConversion, f16_to_fe5m2_rn_array) {
  // 27-element half -> e5m2 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::float_e5m2_t, cutlass::half_t, 27>();
}
TEST(NumericConversion, bf16_to_fe4m3_rn) {
  // Scalar bfloat16 -> e4m3 device conversion.
  test::core::kernel::run_test<cutlass::float_e4m3_t, cutlass::bfloat16_t, 1>();
}
TEST(NumericConversion, bf16_to_fe4m3_rn_array) {
  // 27-element bfloat16 -> e4m3 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::float_e4m3_t, cutlass::bfloat16_t, 27>();
}
TEST(NumericConversion, bf16_to_fe5m2_rn) {
  // Scalar bfloat16 -> e5m2 device conversion.
  test::core::kernel::run_test<cutlass::float_e5m2_t, cutlass::bfloat16_t, 1>();
}
TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
  // 27-element bfloat16 -> e5m2 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::float_e5m2_t, cutlass::bfloat16_t, 27>();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(NumericConversion, fe4m3_to_fe5m2_rn) {
  // Scalar e4m3 -> e5m2 device conversion.
  test::core::kernel::run_test<cutlass::float_e5m2_t, cutlass::float_e4m3_t, 1>();
}
TEST(NumericConversion, fe4m3_to_fe5m2_array) {
  // 27-element e4m3 -> e5m2 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::float_e5m2_t, cutlass::float_e4m3_t, 27>();
}
TEST(NumericConversion, fe5m2_to_fe4m3_rn) {
  // Scalar e5m2 -> e4m3 device conversion.
  test::core::kernel::run_test<cutlass::float_e4m3_t, cutlass::float_e5m2_t, 1>();
}
TEST(NumericConversion, fe5m2_to_fe4m3_array) {
  // 27-element e5m2 -> e4m3 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::float_e4m3_t, cutlass::float_e5m2_t, 27>();
}
TEST(NumericConversion, fe4m3_to_f32_rn) {
  // Scalar e4m3 -> float device conversion.
  test::core::kernel::run_test<float, cutlass::float_e4m3_t, 1>();
}
TEST(NumericConversion, fe4m3_to_f32_array) {
  // 27-element e4m3 -> float device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<float, cutlass::float_e4m3_t, 27>();
}
TEST(NumericConversion, fe5m2_to_f32_rn) {
  // Scalar e5m2 -> float device conversion.
  test::core::kernel::run_test<float, cutlass::float_e5m2_t, 1>();
}
TEST(NumericConversion, fe5m2_to_f32_array) {
  // 27-element e5m2 -> float device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<float, cutlass::float_e5m2_t, 27>();
}
TEST(NumericConversion, fe4m3_to_f16_rn) {
  // Scalar e4m3 -> half device conversion.
  test::core::kernel::run_test<cutlass::half_t, cutlass::float_e4m3_t, 1>();
}
TEST(NumericConversion, fe4m3_to_f16_array) {
  // 27-element e4m3 -> half device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::half_t, cutlass::float_e4m3_t, 27>();
}
TEST(NumericConversion, fe5m2_to_f16_rn) {
  // Scalar e5m2 -> half device conversion.
  test::core::kernel::run_test<cutlass::half_t, cutlass::float_e5m2_t, 1>();
}
TEST(NumericConversion, fe5m2_to_f16_array) {
  // 27-element e5m2 -> half device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::half_t, cutlass::float_e5m2_t, 27>();
}
TEST(NumericConversion, fe4m3_to_bf16_rn) {
  // Scalar e4m3 -> bfloat16 device conversion.
  test::core::kernel::run_test<cutlass::bfloat16_t, cutlass::float_e4m3_t, 1>();
}
TEST(NumericConversion, fe4m3_to_bf16_array) {
  // 27-element e4m3 -> bfloat16 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::bfloat16_t, cutlass::float_e4m3_t, 27>();
}
TEST(NumericConversion, fe5m2_to_bf16_rn) {
  // Scalar e5m2 -> bfloat16 device conversion.
  test::core::kernel::run_test<cutlass::bfloat16_t, cutlass::float_e5m2_t, 1>();
}
TEST(NumericConversion, fe5m2_to_bf16_array) {
  // 27-element e5m2 -> bfloat16 device conversion (non-multiple-of-vector count).
  test::core::kernel::run_test<cutlass::bfloat16_t, cutlass::float_e5m2_t, 27>();
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -34,6 +34,7 @@
#include "../../common/cutlass_unit_test.h"
#include "cutlass/layout/layout.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/util/host_tensor.h"

View File

@@ -68,7 +68,12 @@ cutlass_test_unit_add_executable(
simt_sgemm_nt_sm80.cu
simt_sgemm_tn_sm80.cu
simt_cgemm_nt_sm80.cu
simt_cgemm_tn_sm80.cu
simt_f8gemm_tn_sm50.cu
simt_cgemm_nn_sm50.cu
simt_cgemm_nt_sm50.cu
simt_cgemm_tn_sm50.cu
@@ -239,6 +244,13 @@ cutlass_test_unit_add_executable(
gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu
# SM90 device level tests
gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu
gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu
gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu
gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu
gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu
gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu
)
cutlass_test_unit_add_executable(
@@ -430,7 +442,7 @@ cutlass_test_unit_add_executable(
BATCH_SOURCES ON
BATCH_SIZE 4
## SYRK
## SYRK
# Syrk SM80 f64 tests
syrk_f64n_f64t_tensor_op_f64_sm80.cu
syrk_f64t_f64n_tensor_op_f64_sm80.cu
@@ -452,6 +464,12 @@ cutlass_test_unit_add_executable(
syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu
syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu
# Syrk SM90 f64 tests
syrk_f64_f64_tensor_op_f64_sm90.cu
# Syrk SM90 complex f64 tests
syrk_cf64_cf64_tensor_op_f64_sm90.cu
## HERK
# Herk SM80 complex f64 tests
herk_cf64h_cf64n_tensor_op_f64_sm80.cu
@@ -460,6 +478,9 @@ cutlass_test_unit_add_executable(
herk_cf32h_cf32n_tensor_op_f32_sm80.cu
herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu
# Herk SM90 complex f64 tests
herk_cf64_cf64_tensor_op_f64_sm90.cu
## TRMM
# Trmm SM80 f64 tests
trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu
@@ -486,6 +507,12 @@ cutlass_test_unit_add_executable(
trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu
trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu
# Trmm SM90 f64 tests
trmm_f64_f64_f64_tensor_op_f64_sm90.cu
# Trmm SM90 complex f64 tests
trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu
## SYR2K
# Syr2k SM80 f64 tests
syr2k_f64n_f64t_tensor_op_f64_sm80.cu
@@ -508,6 +535,12 @@ cutlass_test_unit_add_executable(
syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu
syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu
# Syr2k SM90 f64 tests
syr2k_f64_f64_tensor_op_f64_sm90.cu
# Syr2k SM90 complex f64 tests
syr2k_cf64_cf64_tensor_op_f64_sm90.cu
## HER2K
# Her2k SM80 complex f64 tests
her2k_cf64n_cf64n_tensor_op_f64_sm80.cu
@@ -516,6 +549,9 @@ cutlass_test_unit_add_executable(
her2k_cf32h_cf32n_tensor_op_f32_sm80.cu
her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu
# Her2k SM90 complex f64 tests
her2k_cf64_cf64_tensor_op_f64_sm90.cu
## SYMM
# Symm SM80 f64 tests
symm_f64n_f64n_tensor_op_f64_ls_sm80.cu
@@ -546,6 +582,12 @@ cutlass_test_unit_add_executable(
symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu
symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu
# Symm SM90 f64 tests
symm_f64_f64_tensor_op_f64_sm90.cu
# Symm SM90 complex f64 tests
symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu
# Hemm SM80 complex f64 tests
hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu
hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu
@@ -556,7 +598,10 @@ cutlass_test_unit_add_executable(
hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu
hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu
hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu
)
# Hemm SM90 complex f64 tests
hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_grouped_blas3
@@ -582,3 +627,13 @@ cutlass_test_unit_add_executable(
)
endif()
if (NOT CUDA_COMPILER MATCHES "[Cc]lang")
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_broadcast
gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu
)
endif()

View File

@@ -71,7 +71,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
@@ -93,7 +93,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
@@ -115,7 +115,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) {
@@ -137,7 +137,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) {
@@ -159,7 +159,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
@@ -181,7 +181,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
@@ -203,7 +203,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
@@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -70,7 +70,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) {
@@ -90,7 +90,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) {
@@ -111,7 +111,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) {
@@ -131,7 +131,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) {
@@ -151,7 +151,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) {
@@ -171,7 +171,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) {
@@ -191,7 +191,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) {
@@ -211,7 +211,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) {
@@ -231,7 +231,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
@@ -251,7 +251,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
@@ -271,7 +271,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) {
@@ -291,7 +291,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) {
@@ -311,7 +311,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
@@ -331,7 +331,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
@@ -351,7 +351,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
@@ -371,7 +371,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////

View File

@@ -83,7 +83,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8
cutlass::arch::OpXorPopc
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) {
@@ -114,7 +114,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) {
@@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) {
@@ -176,7 +176,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) {
@@ -207,7 +207,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) {
@@ -238,6 +238,6 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED

View File

@@ -71,7 +71,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) {
@@ -93,7 +93,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) {
@@ -115,7 +115,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512) {
@@ -137,7 +137,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512) {
@@ -159,7 +159,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) {
@@ -180,7 +180,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) {
@@ -202,7 +202,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) {
@@ -224,7 +224,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -83,7 +83,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8
cutlass::arch::OpXorPopc
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) {
@@ -114,7 +114,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) {
@@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) {
@@ -176,7 +176,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) {
@@ -207,7 +207,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) {
@@ -238,6 +238,6 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED

View File

@@ -65,7 +65,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x64_64x64x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) {
@@ -83,7 +83,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) {
@@ -101,7 +101,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) {
@@ -119,7 +119,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) {
@@ -137,7 +137,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) {
@@ -155,7 +155,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) {
@@ -173,7 +173,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) {
@@ -191,7 +191,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) {
@@ -209,7 +209,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) {
@@ -227,7 +227,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) {
@@ -245,7 +245,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) {
@@ -263,7 +263,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) {
@@ -281,7 +281,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) {
@@ -299,7 +299,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) {
@@ -317,7 +317,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) {
@@ -335,7 +335,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////

View File

@@ -56,7 +56,7 @@
// Operands data type: complex<float>
// Rounding: float -> tfloat32_t (half_ulp_truncate)
// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part)
// Math instruction: MMA.1688.F32.TF32
// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32
// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part)
// Output data type: complex<float>
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -56,7 +56,7 @@
// Operands data type: complex<float>
// Rounding: float -> tfloat32_t (round to nearest)
// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part)
// Math instruction: MMA.1688.F32.TF32
// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32
// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part)
// Output data type: complex<float>
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,198 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface with Hopper FP64
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM90 tensor-op GEMM on cutlass::complex<double> using the Gaussian complex
// multiply-add (OpMultiplyAddGaussianComplex): A column-major x B row-major
// -> C/D row-major. Threadblock tile 32x32x16, warp tile 16x16x16, MMA
// instruction shape 16x8x4, 3 pipeline stages, no conjugation of either operand.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
cutlass::arch::OpMultiplyAddGaussianComplex
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Same Gaussian-complex SM90 configuration as above with a shallower K tile:
// threadblock 32x32x8, warp 16x16x8, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<32, 32, 8>,
cutlass::gemm::GemmShape<16, 16, 8>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
cutlass::arch::OpMultiplyAddGaussianComplex
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Gaussian-complex SM90 GEMM with a larger CTA tile: threadblock 64x64x16,
// warp 16x32x16, MMA instruction 16x8x4, 3 stages; layouts as in the tests above.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
cutlass::arch::OpMultiplyAddGaussianComplex
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Gaussian-complex SM90 GEMM, 64x64 CTA tile with the shallower K=8 variant:
// threadblock 64x64x8, warp 16x32x8, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<64, 64, 8>,
cutlass::gemm::GemmShape<16, 32, 8>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
cutlass::arch::OpMultiplyAddGaussianComplex
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,252 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface with Hopper FP64
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM90 tensor-op GEMM on cutlass::complex<double> with the default complex
// multiply-add (no explicit Operator argument): A column-major x B row-major
// -> C/D row-major. Threadblock 32x32x16, warp 16x16x16, MMA instruction
// 16x8x4, 3 pipeline stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Default-complex SM90 GEMM, shallower K tile: threadblock 32x32x8,
// warp 16x16x8, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<32, 32, 8>,
cutlass::gemm::GemmShape<16, 16, 8>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Default-complex SM90 GEMM, larger CTA tile: threadblock 64x64x16,
// warp 16x32x16, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Default-complex SM90 GEMM, 64x64 CTA tile with K=8: threadblock 64x64x8,
// warp 16x32x8, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<64, 64, 8>,
cutlass::gemm::GemmShape<16, 32, 8>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Default-complex SM90 GEMM, square warp tiling: threadblock 64x64x16,
// warp 32x32x16, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Default-complex SM90 GEMM, square warp tiling with K=8: threadblock 64x64x8,
// warp 32x32x8, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<64, 64, 8>,
cutlass::gemm::GemmShape<32, 32, 8>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface with Hopper FP64
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM90 Gaussian-complex tensor-op GEMM with transposed operand layouts
// relative to the cf64n_cf64t suite: A row-major x B column-major -> C/D
// row-major. Threadblock 32x32x8, warp 16x16x8, MMA instruction 16x8x4,
// 3 stages, no conjugation of either operand.
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<32, 32, 8>,
cutlass::gemm::GemmShape<16, 16, 8>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
cutlass::arch::OpMultiplyAddGaussianComplex
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// Gaussian-complex SM90 GEMM (A row-major, B column-major), larger CTA tile:
// threadblock 64x64x8, warp 32x16x8, MMA instruction 16x8x4, 3 stages.
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) {
using Element = cutlass::complex<double>;
using Gemm = cutlass::gemm::device::GemmComplex<
Element,
cutlass::layout::RowMajor,
Element,
cutlass::layout::ColumnMajor,
Element,
cutlass::layout::RowMajor,
Element,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm90,
cutlass::gemm::GemmShape<64, 64, 8>,
cutlass::gemm::GemmShape<32, 16, 8>,
cutlass::gemm::GemmShape<16, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
Element,
1,
Element,
Element
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
cutlass::arch::OpMultiplyAddGaussianComplex
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) {
  // Deeper-K variant (k = 16 per threadblock tile) of the Gaussian
  // complex<double> GEMM test. A row-major, B column-major, C/D row-major.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                     // pipeline stages
      cutlass::ComplexTransform::kNone,      // transform A
      cutlass::ComplexTransform::kNone,      // transform B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) {
  // Deeper-K (k = 16) Gaussian complex<double> GEMM with a 64x64 threadblock
  // tile. Layouts: A row-major, B column-major, C/D row-major.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 64, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 16, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                     // pipeline stages
      cutlass::ComplexTransform::kNone,      // transform A
      cutlass::ComplexTransform::kNone,      // transform B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,305 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface with Hopper FP64
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) {
  // complex<double> GEMM on SM90 tensor cores. A row-major, B column-major,
  // C/D row-major. Remaining GemmComplex template parameters (complex
  // transforms, multiply-add operator) are left at their defaults.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 8>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 8>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) {
  // complex<double> GEMM on SM90 tensor cores, 64x64 threadblock tile.
  // A row-major, B column-major, C/D row-major; default complex transforms.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 64, 8>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 8>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) {
  // complex<double> GEMM on SM90 tensor cores, rectangular 64x128 threadblock
  // tile. A row-major, B column-major, C/D row-major; default transforms.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 128, 8>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 8>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) {
  // complex<double> GEMM on SM90 tensor cores, rectangular 128x64 threadblock
  // tile. A row-major, B column-major, C/D row-major; default transforms.
  //
  // Fix: the threadblock tile was GemmShape<64, 128, 8>, which exactly
  // duplicated the preceding 64x128x8_32x32x8 test instead of exercising the
  // 128x64 tile this test is named for. Corrected to <128, 64, 8>.
  using Element = cutlass::complex<double>;
  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 64, 8>,  // threadblock tile (matches test name)
      cutlass::gemm::GemmShape<32, 32, 8>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          Element, 1, Element, Element>,     // D = alpha * accum + beta * C
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages
  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) {
  // Deeper-K (k = 16) complex<double> GEMM on SM90 tensor cores.
  // A row-major, B column-major, C/D row-major; default transforms.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) {
  // Deeper-K (k = 16) complex<double> GEMM, 64x64 threadblock tile.
  // A row-major, B column-major, C/D row-major; default transforms.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 64, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) {
  // Deeper-K (k = 16) complex<double> GEMM, rectangular 64x128 threadblock
  // tile. A row-major, B column-major, C/D row-major; default transforms.
  using ComplexT = cutlass::complex<double>;

  // Scalar epilogue: D = alpha * accum + beta * C.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<ComplexT, 1, ComplexT, ComplexT>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexT, cutlass::layout::RowMajor,
      ComplexT, cutlass::layout::ColumnMajor,
      ComplexT, cutlass::layout::RowMajor,
      ComplexT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 128, 16>, // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) {
  // Deeper-K (k = 16) complex<double> GEMM, rectangular 128x64 threadblock
  // tile. A row-major, B column-major, C/D row-major; default transforms.
  //
  // Fix: the threadblock tile was GemmShape<64, 128, 16>, which exactly
  // duplicated the preceding 64x128x16_32x32x16 test instead of exercising
  // the 128x64 tile this test is named for. Corrected to <128, 64, 16>.
  using Element = cutlass::complex<double>;
  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 64, 16>, // threadblock tile (matches test name)
      cutlass::gemm::GemmShape<32, 32, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          Element, 1, Element, Element>,     // D = alpha * accum + beta * C
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                    // pipeline stages
  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -45,6 +45,7 @@
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#include "testbed_universal.h"
#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED)
@@ -104,6 +105,44 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// Stream-K work-decomposition ("_sk") variant: f16 inputs (A column-major,
// B row-major), f32 accumulation, f32 row-major output, via the
// ThreadblockSwizzleStreamK scheduler on SM80 tensor cores.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64_sk, {
  using ElementOutput = float;
  using ElementAccumulator = float;
  using Gemm = cutlass::gemm::device::GemmUniversal<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::RowMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,                              // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, // warp / instruction
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, // 128-bit aligned access
          ElementAccumulator, ElementAccumulator>,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>;           // stream-K swizzle, 3 stages
  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
} )
// Same stream-K configuration as above, but with a column-major f32 output
// (only the C/D layout differs).
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32n_tensor_op_f32, 128x128x64_64x64x64_sk, {
  using ElementOutput = float;
  using ElementAccumulator = float;
  using Gemm = cutlass::gemm::device::GemmUniversal<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::RowMajor,
      ElementOutput, cutlass::layout::ColumnMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,                              // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, // warp / instruction
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, // 128-bit aligned access
          ElementAccumulator, ElementAccumulator>,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>;           // stream-K swizzle, 3 stages
  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
} )
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;

View File

@@ -0,0 +1,440 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for GEMM + broadcast interface
*/
#include <fstream>
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "../../common/cutlass_unit_test.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_elementwise.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gemm.h"
/// Host-side fixture for the GEMM-with-broadcast tests below. Owns the input
/// operands (A, B), a 1 x N bias row (C), two residual matrices (D1, D2), two
/// device outputs (Y1, Y2), and a reference output (Y_ref); provides
/// initialization and a norm-based comparison that dumps all tensors to a
/// file on mismatch.
template<typename GemmElement, typename LayoutA, typename LayoutB, typename LayoutC>
struct TestbedUtils {
  /// Initialization
  cutlass::Distribution::Kind init_A;  // distribution used to fill tensor_A
  cutlass::Distribution::Kind init_B;  // distribution used to fill tensor_B
  cutlass::Distribution::Kind init_C;  // distribution used to fill tensor_C
  uint64_t seed;                       // base RNG seed; per-tensor offsets are added below
  cutlass::HostTensor<GemmElement, LayoutA> tensor_A;  // Input A (M x K)
  cutlass::HostTensor<GemmElement, LayoutB> tensor_B;  // Input B (K x N)
  cutlass::HostTensor<GemmElement, LayoutC> tensor_C;  // Input C: bias, allocated 1 x N (see initialize)
  cutlass::HostTensor<GemmElement, LayoutC> tensor_D1; // Input D: first residual (M x N)
  cutlass::HostTensor<GemmElement, LayoutC> tensor_D2; // Input D: second residual (M x N)
  cutlass::HostTensor<GemmElement, LayoutC> tensor_Y1; // Output Y of the single-source GEMM
  cutlass::HostTensor<GemmElement, LayoutC> tensor_Y2; // Output Y of the two-source GEMM
  cutlass::HostTensor<GemmElement, LayoutC> tensor_Y_ref; // Reference output, built up incrementally

  //
  // Methods
  //

  /// Constructor: records fill distributions and the base seed; tensors are
  /// allocated later by initialize().
  TestbedUtils(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }

  /// Helper to initialize a tensor view according to the requested
  /// distribution. Returns false (and fails the test) for unsupported kinds.
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {
    if (dist_kind == cutlass::Distribution::Uniform) {
      double scope_max, scope_min;
      int bits_input = cutlass::sizeof_bits<Element>::value;
      // NOTE(review): bits_output is computed from the same Element as
      // bits_input, so the "bits_output == 16" branch below effectively keys
      // off the input width — confirm this is intended.
      int bits_output = cutlass::sizeof_bits<Element>::value;
      // Narrower types get tighter value ranges to limit round-off error.
      if (bits_input == 1) {
        scope_max = 2;
        scope_min = 0;
      } else if (bits_input <= 8) {
        scope_max = 2;
        scope_min = -2;
      } else if (bits_output == 16) {
        scope_max = 5;
        scope_min = -5;
      } else {
        scope_max = 8;
        scope_min = -8;
      }
      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, scope_max, scope_min, 0);
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {
      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {
      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else {
      // TODO: Implement the rest
      EXPECT_TRUE(false) << "Not implemented";
      return false;
    }
    return true;
  }

  /// Initializes data structures: sizes every tensor for the given problem,
  /// fills inputs, zeroes outputs, and copies inputs to the device.
  void initialize(cutlass::gemm::GemmCoord problem_size) {
    //
    // Allocate the GEMM workspace
    //
    tensor_A.resize(problem_size.mk());
    tensor_B.resize(problem_size.kn());
    // C is a single row (1 x N) — the tests broadcast it across all M rows.
    tensor_C.resize({1, problem_size.n()});
    tensor_D1.resize(problem_size.mn());
    tensor_D2.resize(problem_size.mn());
    tensor_Y1.resize(problem_size.mn());
    tensor_Y2.resize(problem_size.mn());
    tensor_Y_ref.resize(problem_size.mn());
    EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
    EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
    EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
    // Initialize D data to smaller data range. This helps avoid large roundoff errors.
    int d_scope_min = -2;
    int d_scope_max = 2;
    cutlass::reference::host::TensorFillRandomUniform(tensor_D1.host_view(), seed + 2016, d_scope_max, d_scope_min, 0);
    cutlass::reference::host::TensorFillRandomUniform(tensor_D2.host_view(), seed + 2015, d_scope_max, d_scope_min, 0);
    // Outputs start at zero so compare_reference sees only what the GEMM wrote.
    EXPECT_TRUE(initialize_tensor(tensor_Y1.host_view(), cutlass::Distribution::AllZeros, 0));
    EXPECT_TRUE(initialize_tensor(tensor_Y2.host_view(), cutlass::Distribution::AllZeros, 0));
    EXPECT_TRUE(initialize_tensor(tensor_Y_ref.host_view(), cutlass::Distribution::AllZeros, 0));
    // It is possible to randomly initialize to all zeros, so override this with non-zeros
    // in the upper left corner of each operand.
    tensor_A.host_view().at({0, 0}) = GemmElement(1);
    tensor_B.host_view().at({0, 0}) = GemmElement(1);
    tensor_C.host_view().at({0, 0}) = GemmElement(1);
    tensor_D1.host_view().at({0, 0}) = GemmElement(1);
    tensor_D2.host_view().at({0, 0}) = GemmElement(1);
    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_C.sync_device();
    tensor_D1.sync_device();
    tensor_D2.sync_device();
  }

  /// Compares computed reference with device reference and outputs to a file if incorrect.
  /// Passes when the norm of (Y_ref - Y) is below an absolute 0.1 threshold;
  /// also sanity-checks that no tensor is identically zero.
  bool compare_reference(
    cutlass::gemm::GemmCoord problem_size, cutlass::HostTensor<GemmElement, LayoutC>& tensor_Y_ref, cutlass::HostTensor<GemmElement, LayoutC>& tensor_Y) {
    tensor_Y_ref.sync_host();
    tensor_Y.sync_host();
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y_ref.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y.host_view()), 0);
    bool passed = true;
    float norm_diff = 0;
    // NOTE(review): 0.1 is an absolute (not relative) norm tolerance — may be
    // loose or tight depending on problem size; confirm for new test shapes.
    norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Y_ref.host_view(), tensor_Y.host_view(), float());
    passed = (norm_diff <= 0.1f);
    EXPECT_LT(norm_diff, 0.1f) << " tensor_Y is incorrect";
    if (!passed) {
      // Dump every tensor to aid debugging a mismatch.
      std::ofstream file("errors_testbed_gemm_broadcast_new.txt");
      file
        << "problem: " << problem_size << "\n\n";
      file
        << "capacity: \n"
        << "A: " << tensor_A.capacity()
        << "\nB: " << tensor_B.capacity()
        << "\nC: " << tensor_C.capacity()
        << "\nD1: " << tensor_D1.capacity()
        << "\nD2: " << tensor_D2.capacity()
        << "\nY: " << tensor_Y.capacity()
        << "\n\n"
        << "\nY_ref: " << tensor_Y_ref.capacity()
        << "\n\n";
      file
        << "A =\n" << tensor_A.host_view()
        << "\n\nB =\n" << tensor_B.host_view()
        << "\n\nC =\n" << tensor_C.host_view()
        << "\n\nD1 =\n" << tensor_D1.host_view()
        << "\n\nD2 =\n" << tensor_D2.host_view()
        << "\n\nY =\n" << tensor_Y.host_view()
        << "\n\nY_ref =\n" << tensor_Y_ref.host_view();
    }
    return passed;
  }
};
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
// End-to-end check of GemmUniversalWithBroadcast against a plain
// GemmUniversal reference:
//   1. Run a reference GEMM:       Y_ref = alpha * A*B + beta * C (C broadcast)
//   2. Run the single-source form: Y1    = (A*B + bias) * D1, compare to Y_ref * D1
//   3. Run the two-source form:    Y2    = (A*B + bias) * D1 + D2, compare to Y_ref + D2
TEST(SM80_Device_GemmWithBroadcast_f16t_f16n_f16t_tensor_op_f16, 128x128_32x3_64x64x32_16x8x16) {
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using ArchTag = cutlass::arch::Sm80;
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
  const int kStages = 3;
  // Single batch; all batch strides below are therefore zero.
  const int batch_count = 1;
  const cutlass::half_t alpha(1);
  const cutlass::half_t beta(1);
  const int M = 1024;
  const int K = 10240;
  const int N = 512;
  cutlass::gemm::GemmCoord problem{M, N, K};
  const int batch_stride_A = 0;
  const int batch_stride_B = 0;
  const int batch_stride_C1 = 0;
  const int batch_stride_C2 = 0;
  const int batch_stride_D = 0;
  const int batch_stride_Vector = 0;
  const int batch_stride_Tensor = 0;
  const int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0);
  const int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0);
  const int64_t ldc1 = LayoutC::packed({problem.m(), problem.n()}).stride(0);
  const int64_t ldc2 = LayoutC::packed({problem.m(), problem.n()}).stride(0);
  const int64_t ldd = LayoutC::packed({problem.m(), problem.n()}).stride(0);
  // Zero leading dimensions for the broadcast vector / auxiliary tensor.
  const int64_t ldv = 0;
  const int64_t ldt = 0;
  TestbedUtils<ElementA, LayoutA, LayoutB, LayoutC> utils;
  utils.initialize(problem);
  //
  // Create reference Gemm
  //
  using GemmRef = cutlass::gemm::device::GemmUniversal<
    ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator,
    OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
    cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>,
    ThreadblockSwizzle, kStages>;
  // NOTE(review): ldv (= 0) is passed as the reference GEMM's C stride, so
  // the 1 x N tensor_C row is re-read for every row of the output — this
  // looks like an intentional stride-0 broadcast of the bias; confirm.
  typename GemmRef::Arguments args_ref{
    cutlass::gemm::GemmUniversalMode::kGemm,
    problem,
    batch_count,
    {alpha, beta},
    utils.tensor_A.device_data(),
    utils.tensor_B.device_data(),
    utils.tensor_C.device_data(),
    utils.tensor_Y_ref.device_data(),
    batch_stride_A,
    batch_stride_B,
    batch_stride_C1,
    batch_stride_D,
    lda,
    ldb,
    ldv,
    ldd,
  };
  GemmRef gemm_op_ref;
  size_t workspace_size_ref = GemmRef::get_workspace_size(args_ref);
  cutlass::device_memory::allocation<uint8_t> workspace_ref(workspace_size_ref);
  cutlass::Status status = gemm_op_ref.initialize(args_ref, workspace_ref.get());
  EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  status = gemm_op_ref();
  EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  //
  // Create GemmWithBroadcast from single source
  //
  // Residual epilogue: Y1 = Identity(Identity(accum + bias) * D1).
  using GemmSingle = cutlass::gemm::device::GemmUniversalWithBroadcast<
    ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator,
    OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationResidualBlock<
      ElementOutput, ElementAccumulator, ElementAccumulator,
      ElementAccumulator, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity>,
    ThreadblockSwizzle, kStages>;
  typename GemmSingle::Arguments args_single{
    cutlass::gemm::GemmUniversalMode::kGemm,
    problem,
    batch_count,
    {alpha, beta},
    utils.tensor_A.device_data(),
    utils.tensor_B.device_data(),
    utils.tensor_D1.device_data(),   // residual input
    utils.tensor_Y1.device_data(),   // output
    utils.tensor_C.device_data(),    // broadcast vector (bias)
    /* ptr_Tensor = */ nullptr,
    batch_stride_A,
    batch_stride_B,
    batch_stride_C1,
    batch_stride_D,
    batch_stride_Vector,
    batch_stride_Tensor,
    lda,
    ldb,
    ldc1,
    ldd,
    ldv,
    ldt
  };
  GemmSingle gemm_op_single;
  size_t workspace_size_single = GemmSingle::get_workspace_size(args_single);
  cutlass::device_memory::allocation<uint8_t> workspace_single(workspace_size_single);
  status = gemm_op_single.initialize(args_single, workspace_single.get());
  EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  status = gemm_op_single();
  EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  // Compute the broadcast on the reference previously computed and compare results
  utils.tensor_Y_ref.sync_host();
  cutlass::reference::host::TensorMul(utils.tensor_Y_ref.host_view(), utils.tensor_D1.host_view());
  utils.tensor_Y_ref.sync_device();
  utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y1);
  //
  // Create GemmWithBroadcast from two sources
  //
  // Residual epilogue with a second source: Y2 = (accum + bias) * D1 + D2.
  using GemmDouble = cutlass::gemm::device::GemmUniversalWithBroadcast<
    ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator,
    OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationResidualBlock<
      ElementOutput, ElementAccumulator, ElementAccumulator,
      ElementAccumulator, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity, cutlass::plus>,
    ThreadblockSwizzle, kStages>;
  typename GemmDouble::Arguments args_double{
    cutlass::gemm::GemmUniversalMode::kGemm,
    problem,
    batch_count,
    {alpha, beta},
    utils.tensor_A.device_data(),
    utils.tensor_B.device_data(),
    utils.tensor_D1.device_data(),   // first residual input
    utils.tensor_D2.device_data(),   // second residual input
    utils.tensor_Y2.device_data(),   // output
    utils.tensor_C.device_data(),    // broadcast vector (bias)
    /* ptr_Tensor = */ nullptr,
    batch_stride_A,
    batch_stride_B,
    batch_stride_C1,
    batch_stride_C2,
    batch_stride_D,
    batch_stride_Vector,
    batch_stride_Tensor,
    lda,
    ldb,
    ldc1,
    ldc2,
    ldd,
    ldv,
    ldt
  };
  GemmDouble gemm_op_double;
  size_t workspace_size_double = GemmDouble::get_workspace_size(args_double);
  cutlass::device_memory::allocation<uint8_t> workspace_double(workspace_size_double);
  status = gemm_op_double.initialize(args_double, workspace_double.get());
  EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  status = gemm_op_double();
  EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  // Compute the broadcast on the reference previously computed and compare results
  utils.tensor_Y_ref.sync_host();
  cutlass::reference::host::TensorAdd(utils.tensor_Y_ref.host_view(), utils.tensor_D2.host_view());
  utils.tensor_Y_ref.sync_device();
  utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y2);
}
#endif

View File

@@ -0,0 +1,223 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface with Hopper FP64
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) {
  // FP64 GEMM on SM90 tensor cores: A column-major, B row-major,
  // C/D row-major; accumulation and epilogue compute in double.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<double, 1, double, double>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::ColumnMajor,
      double, cutlass::layout::RowMajor,
      double, cutlass::layout::RowMajor,
      double,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,                              // D = alpha * accum + beta * C
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) {
  // FP64 GEMM on SM90 tensor cores with a 64x64 threadblock tile.
  // A column-major, B row-major, C/D row-major; double throughout.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<double, 1, double, double>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::ColumnMajor,
      double, cutlass::layout::RowMajor,
      double, cutlass::layout::RowMajor,
      double,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 64, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,                              // D = alpha * accum + beta * C
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) {
  // FP64 GEMM on SM90 tensor cores, rectangular 128x64 threadblock tile.
  // A column-major, B row-major, C/D row-major; double throughout.
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<double, 1, double, double>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::ColumnMajor,
      double, cutlass::layout::RowMajor,
      double, cutlass::layout::RowMajor,
      double,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 64, 16>, // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // instruction shape
      Epilogue,                              // D = alpha * accum + beta * C
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                    // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) {

  // FP64 tensor-op GEMM on SM90: column-major A, row-major B, row-major C/D.
  using ElementD = double;
  using Accum    = double;

  // Element-wise epilogue D = alpha * accum + beta * C with scalar (width-1) access.
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<ElementD, 1, Accum, Accum>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::ColumnMajor,   // A
      double, cutlass::layout::RowMajor,      // B
      ElementD, cutlass::layout::RowMajor,    // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) {

  // FP64 tensor-op GEMM on SM90: column-major A, row-major B, row-major C/D.
  using ElementD = double;
  using Accum    = double;

  // Element-wise epilogue D = alpha * accum + beta * C with scalar (width-1) access.
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<ElementD, 1, Accum, Accum>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::ColumnMajor,    // A
      double, cutlass::layout::RowMajor,       // B
      ElementD, cutlass::layout::RowMajor,     // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,      // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,223 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface with Hopper FP64
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) {

  // FP64 tensor-op GEMM on SM90: row-major A, column-major B, row-major C/D.
  using ElementD = double;
  using Accum    = double;

  // Element-wise epilogue D = alpha * accum + beta * C with scalar (width-1) access.
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<ElementD, 1, Accum, Accum>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::RowMajor,      // A
      double, cutlass::layout::ColumnMajor,   // B
      ElementD, cutlass::layout::RowMajor,    // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) {

  // FP64 tensor-op GEMM on SM90: row-major A, column-major B, row-major C/D.
  using ElementD = double;
  using Accum    = double;

  // Element-wise epilogue D = alpha * accum + beta * C with scalar (width-1) access.
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<ElementD, 1, Accum, Accum>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::RowMajor,      // A
      double, cutlass::layout::ColumnMajor,   // B
      ElementD, cutlass::layout::RowMajor,    // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) {

  // FP64 tensor-op GEMM on SM90: row-major A, column-major B, row-major C/D.
  using ElementD = double;
  using Accum    = double;

  // Element-wise epilogue D = alpha * accum + beta * C with scalar (width-1) access.
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<ElementD, 1, Accum, Accum>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::RowMajor,      // A
      double, cutlass::layout::ColumnMajor,   // B
      ElementD, cutlass::layout::RowMajor,    // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) {

  // FP64 tensor-op GEMM on SM90: row-major A, column-major B, row-major C/D.
  using ElementD = double;
  using Accum    = double;

  // Element-wise epilogue D = alpha * accum + beta * C with scalar (width-1) access.
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<ElementD, 1, Accum, Accum>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::RowMajor,      // A
      double, cutlass::layout::ColumnMajor,   // B
      ElementD, cutlass::layout::RowMajor,    // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 64, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) {

  // FP64 tensor-op GEMM on SM90: row-major A, column-major B, row-major C/D.
  using ElementD = double;
  using Accum    = double;

  // Element-wise epilogue D = alpha * accum + beta * C with scalar (width-1) access.
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<ElementD, 1, Accum, Accum>;

  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::RowMajor,       // A
      double, cutlass::layout::ColumnMajor,    // B
      ElementD, cutlass::layout::RowMajor,     // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,      // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -81,7 +81,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) {
@@ -113,7 +113,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) {
@@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x128_64x64x128) {
@@ -177,7 +177,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128) {
@@ -209,7 +209,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) {
@@ -241,7 +241,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) {
@@ -273,7 +273,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) {
@@ -305,7 +305,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -81,7 +81,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) {
@@ -113,7 +113,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) {
@@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128) {
@@ -177,7 +177,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128) {
@@ -209,7 +209,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) {
@@ -240,7 +240,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) {
@@ -272,7 +272,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) {
@@ -304,7 +304,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) {
2
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,135 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide HEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/symm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/symm_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_symm_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) {

  // Complex-double HEMM (left side, lower fill), all operands column-major,
  // using the Gaussian (3-multiply) complex multiply-add on SM90 tensor cores.
  using ComplexF64  = cutlass::complex<double>;
  using ColumnMajor = cutlass::layout::ColumnMajor;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF64, 1, ComplexF64, ComplexF64>;

  using Hemm = cutlass::gemm::device::Symm<
      ComplexF64, ColumnMajor,               // A
      cutlass::SideMode::kLeft,
      cutlass::FillMode::kLower,
      ComplexF64, ColumnMajor,               // B
      ComplexF64, ColumnMajor,               // C
      ComplexF64,                            // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,                                     // stages
      1,                                     // alignment A
      1,                                     // alignment B
      false,                                 // no serial split-K
      cutlass::arch::OpMultiplyAddGaussianComplex,
      cutlass::BlasMode::kHermitian>;

  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Hemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) {

  // Complex-double HEMM (right side, upper fill), all operands column-major,
  // using the standard (4-multiply) complex multiply-add on SM90 tensor cores.
  using ComplexF64  = cutlass::complex<double>;
  using ColumnMajor = cutlass::layout::ColumnMajor;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF64, 1, ComplexF64, ComplexF64>;

  using Hemm = cutlass::gemm::device::Symm<
      ComplexF64, ColumnMajor,               // A
      cutlass::SideMode::kRight,
      cutlass::FillMode::kUpper,
      ComplexF64, ColumnMajor,               // B
      ComplexF64, ColumnMajor,               // C
      ComplexF64,                            // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 64, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,  // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,    // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,                                     // stages
      1,                                     // alignment A
      1,                                     // alignment B
      false,                                 // no serial split-K
      cutlass::arch::OpMultiplyAddComplex,
      cutlass::BlasMode::kHermitian>;

  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Hemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,149 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide HER2K interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/rank_2k.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/rank_2k.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_rank2k_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Her2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) {

  // Complex-double HER2K updating the lower triangle of C; A, B, and C are
  // all column-major (CUBLAS_OP_N-style inputs), no conjugation applied.
  using ComplexF64  = cutlass::complex<double>;
  using ColumnMajor = cutlass::layout::ColumnMajor;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF64, 1, ComplexF64, ComplexF64>;

  using Rank2K = cutlass::gemm::device::Rank2K<
      ComplexF64, ColumnMajor,                // A
      ComplexF64, ColumnMajor,                // B
      ComplexF64, ColumnMajor,                // C
      cutlass::FillMode::kLower,
      ComplexF64,                             // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,                                      // stages
      1,                                      // alignment A
      1,                                      // alignment B
      false,                                  // no serial split-K
      cutlass::arch::OpMultiplyAddComplex,
      cutlass::ComplexTransform::kNone,       // transform on A
      cutlass::ComplexTransform::kNone,       // transform on B
      cutlass::BlasMode::kHermitian>;

  EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal<Rank2K>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Her2k_cf64c_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) {

  // Complex-double HER2K updating the upper triangle of C; A and B are
  // row-major with conjugation (CUBLAS_OP_C-style inputs), C is column-major.
  using ComplexF64 = cutlass::complex<double>;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF64, 1, ComplexF64, ComplexF64>;

  using Rank2K = cutlass::gemm::device::Rank2K<
      ComplexF64, cutlass::layout::RowMajor,     // A
      ComplexF64, cutlass::layout::RowMajor,     // B
      ComplexF64, cutlass::layout::ColumnMajor,  // C
      cutlass::FillMode::kUpper,
      ComplexF64,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,      // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,      // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,        // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,                                         // stages
      1,                                         // alignment A
      1,                                         // alignment B
      false,                                     // no serial split-K
      cutlass::arch::OpMultiplyAddComplex,
      cutlass::ComplexTransform::kConjugate,     // transform on A
      cutlass::ComplexTransform::kConjugate,     // transform on B
      cutlass::BlasMode::kHermitian>;

  EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal<Rank2K>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,93 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide HERK interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/rank_k.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/rank_k_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_rank_k_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts
// Complex-double HERK updating the lower triangle of C; A is row-major with
// conjugation (CUBLAS_OP_C-style input), C is column-major.
//
// Fix: the test was named "64x64x16_32x32x16" while instantiating a
// 32x32x16 threadblock tile with a 16x16x16 warp tile; the name is corrected
// to match the shapes actually tested.
TEST(SM90_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) {
  using ElementA = cutlass::complex<double>;
  using LayoutA = cutlass::layout::RowMajor;
  using ElementC = cutlass::complex<double>;
  using LayoutC = cutlass::layout::ColumnMajor;
  using ElementAccumulator = cutlass::complex<double>;

  using RankK = cutlass::gemm::device::RankK<
      ElementA, LayoutA,
      ElementC, LayoutC,
      cutlass::FillMode::kLower,
      ElementAccumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,     // MMA instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 1, ElementAccumulator, ElementAccumulator>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,                                      // kStages
      1,                                      // AlignmentA
      false,                                  // SplitKSerial
      cutlass::arch::OpMultiplyAddComplex,
      cutlass::ComplexTransform::kConjugate,  // transform on A
      cutlass::BlasMode::kHermitian>;

  EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -125,7 +125,7 @@ struct MultistageTestbed {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -0,0 +1,265 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 32x64x8_32x64x1) {

  // complex<float> SIMT GEMM on SM80: column-major A, row-major B and C/D.
  // NOTE(review): the instantiated warp tile is 32x32x8, which does not match
  // the "32x64" warp annotation in the test name — confirm which was intended.
  using ComplexF32 = cutlass::complex<float>;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF32, 1, ComplexF32, ComplexF32>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexF32, cutlass::layout::ColumnMajor,  // A
      ComplexF32, cutlass::layout::RowMajor,     // B
      ComplexF32, cutlass::layout::RowMajor,     // C/D
      ComplexF32,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 64, 8>,       // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 8>,       // warp tile
      cutlass::gemm::GemmShape<1, 1, 1>,         // SIMT instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                        // stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x64x8_32x64x1) {

  // complex<float> SIMT GEMM on SM80: column-major A, row-major B and C/D.
  using ComplexF32 = cutlass::complex<float>;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF32, 1, ComplexF32, ComplexF32>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexF32, cutlass::layout::ColumnMajor,  // A
      ComplexF32, cutlass::layout::RowMajor,     // B
      ComplexF32, cutlass::layout::RowMajor,     // C/D
      ComplexF32,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 8>,       // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 8>,       // warp tile
      cutlass::gemm::GemmShape<1, 1, 1>,         // SIMT instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                        // stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_32x64x1) {

  // complex<float> SIMT GEMM on SM80: column-major A, row-major B and C/D,
  // with both operands conjugated.
  using ComplexF32 = cutlass::complex<float>;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF32, 1, ComplexF32, ComplexF32>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexF32, cutlass::layout::ColumnMajor,  // A
      ComplexF32, cutlass::layout::RowMajor,     // B
      ComplexF32, cutlass::layout::RowMajor,     // C/D
      ComplexF32,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 8>,     // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 8>,       // warp tile
      cutlass::gemm::GemmShape<1, 1, 1>,         // SIMT instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                         // stages
      cutlass::ComplexTransform::kConjugate,     // transform on A
      cutlass::ComplexTransform::kConjugate>;    // transform on B

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x128x8_32x64x1) {

  // complex<float> SIMT GEMM on SM80: column-major A, row-major B and C/D,
  // with both operands conjugated.
  using ComplexF32 = cutlass::complex<float>;

  // Element-wise linear-combination epilogue with scalar (width-1) access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ComplexF32, 1, ComplexF32, ComplexF32>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ComplexF32, cutlass::layout::ColumnMajor,  // A
      ComplexF32, cutlass::layout::RowMajor,     // B
      ComplexF32, cutlass::layout::RowMajor,     // C/D
      ComplexF32,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 8>,      // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 8>,       // warp tile
      cutlass::gemm::GemmShape<1, 1, 1>,         // SIMT instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                         // stages
      cutlass::ComplexTransform::kConjugate,     // transform on A
      cutlass::ComplexTransform::kConjugate>;    // transform on B

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// NOTE(review): renamed from 128x64x8_32x64x1 — the warp tile here is
// GemmShape<64, 32, 8>, and the suite's naming convention encodes
// <threadblock>_<warp> (the cf32t_cf32n sibling test with identical shapes
// is named 128x64x8_64x32x1). Shapes themselves are unchanged.
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x64x8_64x32x1) {
  // complex<float> SIMT GEMM on SM80: A column-major, B row-major, C row-major.
  // Threadblock tile 128x64x8, warp tile 64x32x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,  // A
      Element, cutlass::layout::RowMajor,     // B
      Element, cutlass::layout::RowMajor,     // C
      Element,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_64x64x1) {
  // complex<float> SIMT GEMM on SM80: A column-major, B row-major, C row-major.
  // Threadblock tile 128x128x8, warp tile 64x64x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,  // A
      Element, cutlass::layout::RowMajor,     // B
      Element, cutlass::layout::RowMajor,     // C
      Element,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x256x8_64x64x1) {
  // complex<float> SIMT GEMM on SM80 with both operands conjugated.
  // Threadblock tile 128x256x8, warp tile 64x64x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,   // A
      Element, cutlass::layout::RowMajor,      // B
      Element, cutlass::layout::RowMajor,      // C
      Element,                                 // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                       // stages
      cutlass::ComplexTransform::kConjugate,   // transform on A
      cutlass::ComplexTransform::kConjugate>;  // transform on B

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,269 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// NOTE(review): renamed from 32x64x8_32x64x1 — the warp tile here is
// GemmShape<32, 32, 8>, and the suite's naming convention encodes
// <threadblock>_<warp>. Shapes themselves are unchanged.
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 32x64x8_32x32x1) {
  // complex<float> SIMT GEMM on SM80: A row-major, B column-major, C row-major.
  // Threadblock tile 32x64x8, warp tile 32x32x8, 4 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,     // A
      Element, cutlass::layout::ColumnMajor,  // B
      Element, cutlass::layout::RowMajor,     // C
      Element,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x64x8_32x64x1) {
  // complex<float> SIMT GEMM on SM80 with both operands conjugated.
  // Threadblock tile 64x64x8, warp tile 32x64x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C
      Element,                                 // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                       // stages
      cutlass::ComplexTransform::kConjugate,   // transform on A
      cutlass::ComplexTransform::kConjugate>;  // transform on B

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_32x64x1) {
  // complex<float> SIMT GEMM on SM80: A row-major, B column-major, C row-major.
  // Threadblock tile 128x128x8, warp tile 32x64x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,     // A
      Element, cutlass::layout::ColumnMajor,  // B
      Element, cutlass::layout::RowMajor,     // C
      Element,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x128x8_32x64x1) {
  // complex<float> SIMT GEMM on SM80 with both operands conjugated.
  // Threadblock tile 64x128x8, warp tile 32x64x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C
      Element,                                 // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                       // stages
      cutlass::ComplexTransform::kConjugate,   // transform on A
      cutlass::ComplexTransform::kConjugate>;  // transform on B

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x64x8_64x32x1) {
  // complex<float> SIMT GEMM on SM80 with both operands conjugated.
  // Threadblock tile 128x64x8, warp tile 64x32x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C
      Element,                                 // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                       // stages
      cutlass::ComplexTransform::kConjugate,   // transform on A
      cutlass::ComplexTransform::kConjugate>;  // transform on B

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_64x64x1) {
  // complex<float> SIMT GEMM on SM80: A row-major, B column-major, C row-major.
  // Threadblock tile 128x128x8, warp tile 64x64x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,     // A
      Element, cutlass::layout::ColumnMajor,  // B
      Element, cutlass::layout::RowMajor,     // C
      Element,                                // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x256x8_64x64x1) {
  // complex<float> SIMT GEMM on SM80 with both operands conjugated.
  // Threadblock tile 128x256x8, warp tile 64x64x8, 3 pipeline stages.
  using Element = cutlass::complex<float>;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C
      Element,                                 // accumulator
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                       // stages
      cutlass::ComplexTransform::kConjugate,   // transform on A
      cutlass::ComplexTransform::kConjugate>;  // transform on B

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,87 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
////////////////////////////////////////////////////////////////////////////////
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
TEST(SM50_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_simt_f32, 32x64x8_32x64x1) {
  // FP8 (e4m3) SIMT GEMM with float accumulation: A row-major, B column-major,
  // C row-major. Threadblock tile 32x64x8, warp tile 32x64x8.
  // Guarded above: float_e4m3_t requires nvcc >= 11.4.
  using ElementA = cutlass::float_e4m3_t;
  using ElementB = cutlass::float_e4m3_t;
  using ElementC = cutlass::float_e4m3_t;
  using ElementAccumulator = float;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Scalar (vector width 1) linear-combination epilogue accumulating in float.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementC, 1, ElementAccumulator, ElementC>;

  using Gemm = cutlass::gemm::device::Gemm<
      ElementA, cutlass::layout::RowMajor,     // A
      ElementB, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,     // C
      ElementAccumulator,
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm50,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
}
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -196,7 +196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -856,6 +856,6 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )

View File

@@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -196,7 +196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -856,6 +856,6 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )

View File

@@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -196,7 +196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -856,6 +856,6 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )

View File

@@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -196,7 +196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
@@ -856,6 +856,6 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, {
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2 // Stages
>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>());
} )

View File

@@ -0,0 +1,133 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide SYMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/symm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/symm_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_symm_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) {
// Instantiates the device-wide SYMM kernel for complex<double> operands on
// SM90 tensor cores using the Gaussian complex multiply-add variant, then
// runs the shared universal SYMM test suite over it.
using ElementOutput = cutlass::complex<double>;
using ElementAccumulator = cutlass::complex<double>;
using Symm = cutlass::gemm::device::Symm<
cutlass::complex<double>,              // ElementA
cutlass::layout::ColumnMajor,          // LayoutA
cutlass::SideMode::kLeft,              // symmetric operand multiplies from the left
cutlass::FillMode::kLower,             // only the lower triangle is referenced
cutlass::complex<double>,              // ElementB
cutlass::layout::ColumnMajor,          // LayoutB
ElementOutput,                         // ElementC
cutlass::layout::ColumnMajor,          // LayoutC
ElementAccumulator,
cutlass::arch::OpClassTensorOp,        // tensor-core MMA path
cutlass::arch::Sm90,                   // target architecture tag
cutlass::gemm::GemmShape<32, 32, 16>,  // threadblock tile (matches test name)
cutlass::gemm::GemmShape<16, 16, 16>,  // warp tile (matches test name)
cutlass::gemm::GemmShape<16, 8, 4>,    // MMA instruction shape
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,                                     // elements per vectorized epilogue access
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4,                                     // kStages
1,                                     // AlignmentA
1,                                     // AlignmentB
false,                                 // SplitKSerial
cutlass::arch::OpMultiplyAddGaussianComplex  // Gaussian complex multiply-add
>;
EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) {
  // Device-wide SYMM instantiation: complex<double> operands, all column-major,
  // with the symmetric operand applied from the right using its upper triangle.
  // Runs the shared universal SYMM test suite against it.
  using Element = cutlass::complex<double>;
  using Layout  = cutlass::layout::ColumnMajor;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element,   // output element
      1,         // elements per vectorized epilogue access
      Element,   // accumulator element
      Element>;  // element type for alpha/beta computation

  using Symm = cutlass::gemm::device::Symm<
      Element, Layout,                         // A
      cutlass::SideMode::kRight,
      cutlass::FillMode::kUpper,
      Element, Layout,                         // B
      Element, Layout,                         // C
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<64, 64, 16>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,      // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,                                       // stages
      1,                                       // alignment A
      1,                                       // alignment B
      false,                                   // serial split-K
      cutlass::arch::OpMultiplyAddComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,135 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide SYMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/symm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/symm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_symm_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) {
  // Double-precision SYMM on SM90 tensor cores: every operand is column-major
  // double, the symmetric operand sits on the right side, and only its lower
  // triangle is referenced. Exercised via the shared universal SYMM testbed.
  using Element = double;
  using Layout  = cutlass::layout::ColumnMajor;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element,   // output element
      1,         // elements per vectorized epilogue access
      Element,   // accumulator element
      Element>;  // element type for alpha/beta computation

  using Symm = cutlass::gemm::device::Symm<
      Element, Layout,                         // A
      cutlass::SideMode::kRight,
      cutlass::FillMode::kLower,
      Element, Layout,                         // B
      Element, Layout,                         // C
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<32, 32, 16>,    // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,      // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) {
  // Double-precision SYMM on SM90 tensor cores with row-major operands: the
  // symmetric operand multiplies from the left and only its lower triangle is
  // referenced. Uses a larger 128x128 threadblock tile with 3 pipeline stages.
  using Element = double;
  using Layout  = cutlass::layout::RowMajor;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element,   // output element
      1,         // elements per vectorized epilogue access
      Element,   // accumulator element
      Element>;  // element type for alpha/beta computation

  using Symm = cutlass::gemm::device::Symm<
      Element, Layout,                         // A
      cutlass::SideMode::kLeft,
      cutlass::FillMode::kLower,
      Element, Layout,                         // B
      Element, Layout,                         // C
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 4>,      // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,150 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide SYR2K (rank-2k update) interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/rank_2k.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/rank_2k.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_rank2k_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) {

  // complex<double> rank-2k update on SM90 tensor cores.  A, B, and C are
  // all column-major ("n"/"n" in the test name); FillMode::kLower selects
  // the lower triangle of C.
  using Complex = cutlass::complex<double>;
  using ColumnMajor = cutlass::layout::ColumnMajor;

  using Rank2K = cutlass::gemm::device::Rank2K<
    Complex, ColumnMajor,                                 // A
    Complex, ColumnMajor,                                 // B
    Complex, ColumnMajor,                                 // C
    cutlass::FillMode::kLower,
    Complex,                                              // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<32, 32, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,      // kStages
    1,      // AlignmentA
    1,      // AlignmentB
    false,  // SplitKSerial
    cutlass::arch::OpMultiplyAddComplex,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone,
    cutlass::BlasMode::kSymmetric
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal<Rank2K>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Syr2k_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) {

  // complex<double> rank-2k update on SM90 tensor cores.  A and B are
  // column-major while C is row-major; FillMode::kUpper selects the upper
  // triangle of C.
  using Complex = cutlass::complex<double>;

  using Rank2K = cutlass::gemm::device::Rank2K<
    Complex, cutlass::layout::ColumnMajor,                // A
    Complex, cutlass::layout::ColumnMajor,                // B
    Complex, cutlass::layout::RowMajor,                   // C
    cutlass::FillMode::kUpper,
    Complex,                                              // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<32, 32, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,      // kStages
    1,      // AlignmentA
    1,      // AlignmentB
    false,  // SplitKSerial
    cutlass::arch::OpMultiplyAddComplex,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone,
    cutlass::BlasMode::kSymmetric
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal<Rank2K>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,134 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Tests for device-wide Rank 2k update (SYR2K) interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/rank_2k.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/rank_2k.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_rank2k_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) {

  // Double-precision rank-2k update on SM90 tensor cores.  A, B, and C are
  // all column-major; FillMode::kLower selects the lower triangle of C.
  using ColumnMajor = cutlass::layout::ColumnMajor;

  using Rank2K = cutlass::gemm::device::Rank2K<
    double, ColumnMajor,                                  // A
    double, ColumnMajor,                                  // B
    double, ColumnMajor,                                  // C
    cutlass::FillMode::kLower,
    double,                                               // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<32, 32, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      double, 1, double, double>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4                                                     // kStages
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal<Rank2K>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Double-precision rank-2k update on SM90 tensor cores with a larger
// 128x128x16 threadblock tile and 3 pipeline stages.
// NOTE(review): the test name says "f64t_f64n" (A row-major, B column-major),
// but LayoutB below is RowMajor — confirm whether the name or the layout is
// the intended configuration.
TEST(SM90_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) {
  using ElementA = double;
  using LayoutA = cutlass::layout::RowMajor;
  using ElementB = double;
  using LayoutB = cutlass::layout::RowMajor;
  using ElementC = double;
  using LayoutC = cutlass::layout::ColumnMajor;
  using ElementAccumulator = double;

  using Rank2K = cutlass::gemm::device::Rank2K<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    cutlass::FillMode::kLower,                            // lower triangle of C
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<128, 128, 16>,               // threadblock tile
    cutlass::gemm::GemmShape<32, 64, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      1,
      ElementAccumulator,
      ElementAccumulator
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3                                                     // kStages
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal<Rank2K>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,136 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide SYRK interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/rank_k.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/rank_k_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_rank_k_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Syrk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) {

  // complex<double> rank-k (SYRK) update on SM90 tensor cores.  A and C are
  // column-major; FillMode::kLower selects the lower triangle of C.
  using Complex = cutlass::complex<double>;

  using RankK = cutlass::gemm::device::RankK<
    Complex, cutlass::layout::ColumnMajor,                // A
    Complex, cutlass::layout::ColumnMajor,                // C
    cutlass::FillMode::kLower,
    Complex,                                              // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<32, 32, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,      // kStages
    1,      // AlignmentA
    false,  // SplitKSerial
    cutlass::arch::OpMultiplyAddComplex,
    cutlass::ComplexTransform::kNone,
    cutlass::BlasMode::kSymmetric
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// complex<double> rank-k (SYRK) update on SM90 tensor cores using the
// Gaussian complex multiply-add variant (OpMultiplyAddGaussianComplex).
// A is column-major and C is row-major; FillMode::kLower selects the lower
// triangle of C.
TEST(SM90_Device_Syrk_cf64n_cf64t_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) {
  using ElementA = cutlass::complex<double>;
  using LayoutA = cutlass::layout::ColumnMajor;
  using ElementC = cutlass::complex<double>;
  using LayoutC = cutlass::layout::RowMajor;
  using ElementAccumulator = cutlass::complex<double>;

  using RankK = cutlass::gemm::device::RankK<
    ElementA,
    LayoutA,
    ElementC,
    LayoutC,
    cutlass::FillMode::kLower,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<32, 32, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      1,
      ElementAccumulator,
      ElementAccumulator
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4, // kStages
    1, // AlignmentA
    false, // SplitKSerial
    cutlass::arch::OpMultiplyAddGaussianComplex,
    cutlass::ComplexTransform::kNone,
    cutlass::BlasMode::kSymmetric
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,126 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide SYRK interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/rank_k.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/rank_k_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_rank_k_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) {

  // Double-precision rank-k (SYRK) update on SM90 tensor cores.  A is
  // column-major, C is row-major; FillMode::kLower selects the lower
  // triangle of C.
  using RankK = cutlass::gemm::device::RankK<
    double, cutlass::layout::ColumnMajor,                 // A
    double, cutlass::layout::RowMajor,                    // C
    cutlass::FillMode::kLower,
    double,                                               // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<128, 64, 16>,                // threadblock tile
    cutlass::gemm::GemmShape<64, 32, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      double, 1, double, double>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4                                                     // kStages
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Syrk_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) {

  // Double-precision rank-k (SYRK) update on SM90 tensor cores.  A is
  // row-major, C is column-major; FillMode::kLower selects the lower
  // triangle of C.
  using RankK = cutlass::gemm::device::RankK<
    double, cutlass::layout::RowMajor,                    // A
    double, cutlass::layout::ColumnMajor,                 // C
    cutlass::FillMode::kLower,
    double,                                               // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<32, 32, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      double, 1, double, double>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4                                                     // kStages
  >;

  EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -50,9 +50,11 @@
#include "cutlass/util/reference/host/gemm.h"
#include "testbed_utils.h"
#include "testbed_universal.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
namespace test {
namespace gemm {
@@ -309,7 +311,7 @@ struct Testbed {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
@@ -319,10 +321,19 @@ struct Testbed {
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size,
cutlass::gemm::GemmCoord problem_size,
int split_k_slices = 1,
ElementCompute alpha = ElementCompute(1),
ElementCompute beta = ElementCompute(0)) {
ElementCompute alpha = ElementCompute(1),
ElementCompute beta = ElementCompute(0))
{
/*
std::cout << "\n-----------------------\n";
std::cout << "problem size: " << problem_size << "\n";
std::cout << "split_k_slices: " << split_k_slices << "\n";
std::cout << "alpha: " << alpha << "\n";
std::cout << "beta: " << beta << "\n";
std::cout << "-----------------------\n\n";
*/
// Waive test if insufficient CUDA device
if (!sufficient()) {
@@ -387,7 +398,7 @@ struct Testbed {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm, bool Relu=false>
bool TestAllGemm(
bool TestAllGemmBasic(
const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(),
const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(),
const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) {
@@ -477,6 +488,52 @@ bool TestAllGemm(
return passed;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm, bool Relu=false>
bool TestAllGemm(
const typename Gemm::LayoutA::Stride& stride_factor_A,
const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(),
const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride())
{
// Test basic GEMM with non-default stride factors
return TestAllGemmBasic<Gemm, Relu>(stride_factor_A, stride_factor_B, stride_factor_C);
}
template <typename Gemm, bool Relu=false>
bool TestAllGemm()
{
#ifdef NDEBUG
// Non-debug builds also test basic GEMM with default stride factors
if (!TestAllGemmBasic<Gemm, Relu>()) {
return false;
}
#endif // NDEBUG
// Test universal GEMM
#if 0
// Define the universal kernel
using UniversalKernel = cutlass::gemm::kernel::GemmUniversal<
typename Gemm::GemmKernel::Mma, // Mma
typename Gemm::GemmKernel::Epilogue, // Epilogue
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle
>;
#else
// Define the streamk universal kernel
using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk<
typename Gemm::GemmKernel::Mma, // Mma
typename Gemm::GemmKernel::Epilogue, // Epilogue
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle
>;
#endif
// Define the universal adaptor
using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter<UniversalKernel>;
// Test universal GEMM
return TestAllGemmUniversal<UniversalGemm, Relu>();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
bool TestGemmPerf(int iterations = 1) {

View File

@@ -128,7 +128,7 @@ struct TestbedComplex : public Testbed<Gemm> {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -388,7 +388,7 @@ struct TestbedGemmWithBroadcast {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -375,7 +375,7 @@ struct TestbedGemmWithReduction {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -312,7 +312,8 @@ template <typename ThreadblockShape,
cutlass::gemm::kernel::GroupScheduleMode... Args>
struct TestbedGroupedGemmScheduler {
using BaselinePV = BaselineProblemVisitor<cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<Transpose>,
using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transpose>;
using BaselinePV = BaselineProblemVisitor<PSHelper,
ThreadblockShape,
PrefetchTileCount,
ThreadCount>;

View File

@@ -130,7 +130,7 @@ struct InterleavedTestbed {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -140,7 +140,7 @@ public:
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -298,7 +298,7 @@ struct TestbedRank2KUniversal {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -286,7 +286,7 @@ struct TestbedRank2KUniversal {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -323,7 +323,7 @@ struct SparseTestbed {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -88,7 +88,7 @@ struct TestbedSplitK : public Testbed<Gemm> {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -324,7 +324,7 @@ struct TestbedSymmUniversal {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -364,7 +364,7 @@ struct TestbedTrmmUniversal {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -58,7 +58,7 @@ namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
template <typename Gemm, bool Relu = false>
struct TestbedUniversal {
using ElementAccumulator = typename Gemm::ElementAccumulator;
@@ -158,9 +158,10 @@ struct TestbedUniversal {
// It is possible to randomly initialize to all zeros, so override this with non-zeros
// in the upper left corner of each operand.
tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1);
tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1);
tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1);
cutlass::Coord<2> origin(0);
tensor_A.host_view().at(origin) = typename Gemm::ElementA(1);
tensor_B.host_view().at(origin) = typename Gemm::ElementB(1);
tensor_C.host_view().at(origin) = typename Gemm::ElementC(1);
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
@@ -253,6 +254,17 @@ struct TestbedUniversal {
ElementAccumulator(0)
);
if (Relu) {
for (int i = 0; i < problem_size.m(); ++i) {
for (int j = 0; j < problem_size.n(); ++j) {
reference_D.at(cutlass::MatrixCoord(i, j)) =
((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0)
? (typename Gemm::ElementC)0
: reference_D.at(cutlass::MatrixCoord(i, j));
}
}
}
return compare_reference(problem_size, alpha, beta);
}
@@ -278,7 +290,7 @@ struct TestbedUniversal {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
@@ -288,10 +300,20 @@ struct TestbedUniversal {
/// Executes one test
bool run(
cutlass::gemm::GemmUniversalMode mode,
cutlass::gemm::GemmCoord problem_size,
cutlass::gemm::GemmCoord problem_size,
int batch_count = 1,
ElementCompute alpha = ElementCompute(1),
ElementCompute beta = ElementCompute(0)) {
ElementCompute alpha = ElementCompute(1),
ElementCompute beta = ElementCompute(0))
{
/*
std::cout << "\n-----------------------\n";
std::cout << "mode: " << (int) mode << "\n";
std::cout << "problem size: " << problem_size << "\n";
std::cout << "batch_count: " << batch_count << "\n";
std::cout << "alpha: " << alpha << "\n";
std::cout << "beta: " << beta << "\n";
std::cout << "-----------------------\n\n";
*/
// Waive test if insufficient CUDA device
if (!sufficient()) {
@@ -359,7 +381,7 @@ struct TestbedUniversal {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
template <typename Gemm, bool Relu = false>
bool TestGemmUniversal(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmUniversalMode mode,
@@ -369,7 +391,7 @@ bool TestGemmUniversal(
bool passed = true;
TestbedUniversal<Gemm> testbed;
TestbedUniversal<Gemm, Relu> testbed;
using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute;
@@ -384,7 +406,7 @@ bool TestGemmUniversal(
return passed;
}
template <typename Gemm>
template <typename Gemm, bool Relu = false>
bool TestAllGemmUniversal() {
bool passed = true;
@@ -412,9 +434,9 @@ bool TestAllGemmUniversal() {
cutlass::platform::is_same<typename Gemm::ElementB, int8_t>::value &&
(cutlass::platform::is_same<typename Gemm::LayoutA, cutlass::layout::RowMajor>::value ||
cutlass::platform::is_same<typename Gemm::LayoutB, cutlass::layout::ColumnMajor>::value) ? 4 : kAlignment;
cutlass::gemm::GemmUniversalMode modes[] = {
cutlass::gemm::GemmUniversalMode::kGemm,
};
@@ -428,8 +450,8 @@ bool TestAllGemmUniversal() {
};
int problem_size_k[] = {
kAlignmentK,
Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK,
kAlignmentK,
Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK,
Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK
};
@@ -468,7 +490,7 @@ bool TestAllGemmUniversal() {
cutlass::gemm::GemmCoord problem_size(m, n, k);
TestbedUniversal<Gemm> testbed;
TestbedUniversal<Gemm, Relu> testbed;
passed = testbed.run(
mode,

View File

@@ -0,0 +1,137 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide TRMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/trmm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/trmm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_trmm_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// complex<double> TRMM on SM90 tensor cores using the Gaussian complex
// multiply-add variant (OpMultiplyAddGaussianComplex): left-side
// (SideMode::kLeft), upper-fill (FillMode::kUpper), non-unit-diagonal
// triangular operand; A and B column-major, output row-major.
TEST(SM90_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 32x32x16_16x16x16) {
  using ElementOutput = cutlass::complex<double>;
  using ElementAccumulator = cutlass::complex<double>;

  using Trmm = cutlass::gemm::device::Trmm<
    cutlass::complex<double>,                             // element A
    cutlass::layout::ColumnMajor,                         // layout A
    cutlass::SideMode::kLeft,
    cutlass::FillMode::kUpper,
    cutlass::DiagType::kNonUnit,
    cutlass::complex<double>,                             // element B
    cutlass::layout::ColumnMajor,                         // layout B
    ElementOutput,
    cutlass::layout::RowMajor,                            // layout C/D
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<32, 32, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      1,
      ElementAccumulator,
      ElementAccumulator
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,      // kStages
    1,      // AlignmentA
    1,      // AlignmentB
    false,  // SplitKSerial
    cutlass::arch::OpMultiplyAddGaussianComplex,
    cutlass::ComplexTransform::kNone
  >;

  EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM90_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) {

  // complex<double> TRMM on SM90 tensor cores: left-side, upper-fill,
  // non-unit-diagonal, with the conjugate transform applied to A
  // ("h" in the test name).  A and B are column-major, output row-major.
  using Complex = cutlass::complex<double>;

  using Trmm = cutlass::gemm::device::Trmm<
    Complex, cutlass::layout::ColumnMajor,                // A
    cutlass::SideMode::kLeft,
    cutlass::FillMode::kUpper,
    cutlass::DiagType::kNonUnit,
    Complex, cutlass::layout::ColumnMajor,                // B
    Complex, cutlass::layout::RowMajor,                   // C/D
    Complex,                                              // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<64, 64, 16>,                 // threadblock tile
    cutlass::gemm::GemmShape<32, 32, 16>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 4>,                   // instruction shape
    cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,      // kStages
    1,      // AlignmentA
    1,      // AlignmentB
    false,  // SplitKSerial
    cutlass::arch::OpMultiplyAddComplex,
    cutlass::ComplexTransform::kConjugate
  >;

  EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,127 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide TRMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/blas3.h"
#include "cutlass/gemm/device/trmm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/trmm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_trmm_universal.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// Device-wide double-precision TRMM on SM90 tensor cores.
// Name decodes as: f64n A (column-major), f64n B (column-major), f64t D
// (row-major); rs/l/nu = right side, lower triangular, non-unit diagonal.
// Tiles: 32x32x16 threadblock, 16x16x16 warp.
TEST(SM90_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) {
using ElementOutput = double;
using ElementAccumulator = double;
using Trmm = cutlass::gemm::device::Trmm<
double,                                 // element type of triangular operand A
cutlass::layout::ColumnMajor,           // layout of A
cutlass::SideMode::kRight,              // A multiplies from the right
cutlass::FillMode::kLower,              // A is lower triangular
cutlass::DiagType::kNonUnit,            // diagonal of A is stored, not assumed unit
double,                                 // element type of B
cutlass::layout::ColumnMajor,           // layout of B
ElementOutput,                          // element type of D
cutlass::layout::RowMajor,              // layout of D
ElementAccumulator,                     // accumulator type
cutlass::arch::OpClassTensorOp,         // tensor-core math
cutlass::arch::Sm90,                    // target architecture
cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
cutlass::gemm::GemmShape<16, 8, 4>,     // instruction shape
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,                                      // elements per vectorized epilogue access
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4                                       // pipeline stages
>;
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Device-wide double-precision TRMM on SM90 tensor cores.
// Name decodes as: f64t A (row-major), f64t B (row-major), f64n D
// (column-major); rs/l/nu = right side, lower triangular, non-unit diagonal.
// Tiles: 64x64x16 threadblock, 32x32x16 warp.
TEST(SM90_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) {
using ElementOutput = double;
using ElementAccumulator = double;
using Trmm = cutlass::gemm::device::Trmm<
double,                                 // element type of triangular operand A
cutlass::layout::RowMajor,              // layout of A
cutlass::SideMode::kRight,              // A multiplies from the right
cutlass::FillMode::kLower,              // A is lower triangular
cutlass::DiagType::kNonUnit,            // diagonal of A is stored, not assumed unit
double,                                 // element type of B
cutlass::layout::RowMajor,              // layout of B
ElementOutput,                          // element type of D
cutlass::layout::ColumnMajor,           // layout of D
ElementAccumulator,                     // accumulator type
cutlass::arch::OpClassTensorOp,         // tensor-core math
cutlass::arch::Sm90,                    // target architecture
cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
cutlass::gemm::GemmShape<16, 8, 4>,     // instruction shape
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
1,                                      // elements per vectorized epilogue access
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4                                       // pipeline stages
>;
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -257,7 +257,7 @@ struct SparseTestbed {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerMultiprocessor < smem_size) {
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}

View File

@@ -37,6 +37,8 @@ cutlass_test_unit_add_executable(
gemm_complex_sm80.cu
gemm_sparse_sm80.cu
gemm_gaussian_complex_sm80.cu
gemm_sm90.cu
gemm_complex_sm90.cu
wmma_sm70.cu
wmma_sm72.cu
wmma_sm75.cu

View File

@@ -56,7 +56,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
// complex<double> * complex<double> => complex<double>
// Input data type: complex<double>
// Math instruction: MMA.884.F64.F64
// Math instruction: mma.sync.aligned.m8n8k4.f64.f64.f64.f64
// Output data type: complex<double>
///////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_nt) {
@@ -293,7 +293,7 @@ TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_tn) {
///////////////////////////////////////////////////////////////////////////////////////////////////
// complex<float> * complex<float> => complex<float>
// Input data type: complex<float>
// Math instruction: MMA.1688.F32.TF32
// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32
// Output data type: complex<float>
// Shared memory layout: Congrous
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -495,7 +495,7 @@ TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_ct) {
///////////////////////////////////////////////////////////////////////////////////////////////////
// complex<float> * complex<float> => complex<float>
// Input data type: complex<float>
// Math instruction: MMA.1688.F32.TF32
// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32
// Output data type: complex<float>
// Shared memory layout: Crosswise
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -526,7 +526,7 @@ TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_tn) {
.run();
}
// TEST FAILS crosswise complex<float> TN MMA.1688.F32.TF32 test fails for k = 2*8 = 16
// TEST FAILS crosswise complex<float> TN mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 test fails for k = 2*8 = 16
TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_tn) {
using Shape = cutlass::gemm::GemmShape<16, 16, 16>;

View File

@@ -0,0 +1,334 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Unit tests for thread-level GEMM with Hopper FP64
*/
#include "../../common/cutlass_unit_test.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/half.h"
#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h"
#include "cutlass/core_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_nt) {
  // Warp-scope complex<double> GEMM: a 16x8x4 warp tile mapped onto the
  // 16x8x4 FP64 MMA; "nt" = column-major A / row-major B, both held in the
  // 128b congruous shared-memory layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<16, 8, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_nt) {
  // Warp-scope complex<double> GEMM: 16x16x4 warp tile on the 16x8x4 FP64
  // MMA; "nt" = column-major A / row-major B in 128b congruous layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<16, 16, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x32x4_16x8x4_nt) {
  // Warp-scope complex<double> GEMM: 16x32x4 warp tile on the 16x8x4 FP64
  // MMA; "nt" = column-major A / row-major B in 128b congruous layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<16, 32, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x16x4_16x8x4_nt) {
  // Warp-scope complex<double> GEMM: 32x16x4 warp tile on the 16x8x4 FP64
  // MMA; "nt" = column-major A / row-major B in 128b congruous layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<32, 16, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nt) {
  // Warp-scope complex<double> GEMM: 32x32x4 warp tile on the 16x8x4 FP64
  // MMA; "nt" = column-major A / row-major B in 128b congruous layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nh) {
  // Warp-scope complex<double> GEMM, "nh": column-major A untransformed,
  // row-major B conjugated (ComplexTransform::kConjugate on B).
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajor,
      cutlass::ComplexTransform::kNone,        // transform on A
      cutlass::ComplexTransform::kConjugate    // transform on B
      >::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_ct) {
  // Warp-scope complex<double> GEMM, "ct": column-major A conjugated
  // (ComplexTransform::kConjugate on A), row-major B untransformed.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b,
      Complex, cutlass::layout::RowMajor,
      cutlass::ComplexTransform::kConjugate,   // transform on A
      cutlass::ComplexTransform::kNone         // transform on B
      >::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_tn) {
  // Warp-scope complex<double> GEMM: 16x8x4 warp tile on the 16x8x4 FP64
  // MMA; "tn" = row-major A / column-major B in the 128x4 crosswise layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<16, 8, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4,
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_tn) {
  // Warp-scope complex<double> GEMM: 16x16x4 warp tile on the 16x8x4 FP64
  // MMA; "tn" = row-major A / column-major B in the 128x4 crosswise layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<16, 16, 4>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4,
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x16_16x8x4_tn) {
  // Warp-scope complex<double> GEMM: 32x32x16 warp tile (multiple MMA steps
  // along K) on the 16x8x4 FP64 MMA; "tn" crosswise operand layouts.
  using Complex = cutlass::complex<double>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      WarpShape,
      cutlass::gemm::GemmShape<16, 8, 4>,  // instruction shape
      Complex, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4,
      Complex, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4,
      Complex, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, WarpShape>().run();
}
// Warp-scope complex<double> GEMM: 64x32x4 warp tile on the 16x8x4 FP64 MMA;
// "tn" = row-major A / column-major B in the 128x4 crosswise layouts.
// Fix: the test was previously named 64x64x4_16x8x4_tn although the
// instantiated warp tile is GemmShape<64, 32, 4>; the name now matches.
TEST(SM90_warp_gemm_complex_tensor_op_f64, 64x32x4_16x8x4_tn) {
  using Shape = cutlass::gemm::GemmShape<64, 32, 4>;   // warp tile
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>;
  using Element = cutlass::complex<double>;
  using ElementC = cutlass::complex<double>;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
      Shape, InstructionShape,
      Element, LayoutA,
      Element, LayoutB,
      ElementC, cutlass::layout::RowMajor>::Type;

  test::gemm::warp::TestbedComplex<MmaTensorOp, Shape>().run();
}
#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,206 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Unit tests for thread-level GEMM with Hopper FP64
*/
#include "../../common/cutlass_unit_test.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/half.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/core_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
TEST(SM90_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_16x8x4) {
  // Warp-level double-precision GEMM: 16x16x4 warp tile on the 16x8x4 FP64
  // MMA; A column-major / B row-major in the 64b congruous layouts.
  using WarpShape = cutlass::gemm::GemmShape<16, 16, 4>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x16x4_32x16x4_16x8x4) {
  // Warp-level double-precision GEMM: 32x16x4 warp tile on the 16x8x4 FP64
  // MMA; A column-major / B row-major in the 64b congruous layouts.
  using WarpShape = cutlass::gemm::GemmShape<32, 16, 4>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x32x4_32x32x4_16x8x4) {
  // Warp-level double-precision GEMM: 32x32x4 warp tile on the 16x8x4 FP64
  // MMA; A column-major / B row-major in the 64b congruous layouts.
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x64x4_32x64x4_16x8x4) {
  // Warp-level double-precision GEMM: 32x64x4 warp tile on the 16x8x4 FP64
  // MMA; A column-major / B row-major in the 64b congruous layouts.
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 4>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 16x16x16_16x16x16_16x8x4) {
  // Warp-level double-precision GEMM: 16x16x16 warp tile on the 16x8x4 FP64
  // MMA; A row-major / B column-major in the 64b crosswise layouts.
  using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x32x16_32x32x16_16x8x4) {
  // Warp-level double-precision GEMM: 32x32x16 warp tile on the 16x8x4 FP64
  // MMA; A row-major / B column-major in the 64b crosswise layouts.
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 64x32x16_64x32x16_16x8x4) {
  // Warp-level double-precision GEMM: 64x32x16 warp tile on the 16x8x4 FP64
  // MMA; A row-major / B column-major in the 64b crosswise layouts.
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_16x8x4) {
  // Warp-level double-precision GEMM: 32x64x16 warp tile on the 16x8x4 FP64
  // MMA; A row-major / B column-major in the 64b crosswise layouts.
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>;

  using Mma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, cutlass::gemm::GemmShape<16, 8, 4>,
      double, cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>::Type;

  test::gemm::warp::Testbed<Mma, WarpShape>().run();
}
////////////////////////////////////////////////////////////////////////////////
#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)