mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 08:50:09 +00:00
Support for Mixed Input TensorOp (#1084)
* Passing warp-level mixed input F16*(S8/U8) tests * passing device-level mixed input F16*(S8/U8) tests * add to profiler - I8 (111 TFLOPs), U (123 TFLOPs) * fast numeric conversions (I8 = 132 TFLOPs, U8 = 148 TFLOPs) * Speedup reference compilation (REVERT THIS COMMIT) * wider_add.u32_packed_sub.f16x2 (I8 = 132TFLOP/s, U8 = 170 TFLOP/s) * Improve s8->f16 cvt and support bf16*u8 @158 TFLOPs * BF16 * S8 (142 TFLOPs) * Handle mixed-input upcast on OperandA (Support [S8|U8]*[F16|BF16] * rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast * Add device-level test and profiler support for upcast on operand A * Move shfl before the cvt and reduce #shfls by 1/2 * fix smem_usage calculation for mixed_input types * uncomment the stuff (getting ready for merge) * profiler changes and mixed-input reference * mixed input reference are in a new file * use platform instead of std * comments and typo only * Use CreateGemmOperator and delete CreateMixedInputGemmOperator * copyright for new files * rebase follow-up
This commit is contained in:
@@ -41,6 +41,7 @@ cutlass_test_unit_add_executable(
|
||||
tensor_view.cu
|
||||
matrix_coord.cu
|
||||
numeric_conversion.cu
|
||||
fast_numeric_conversion.cu
|
||||
functional.cu
|
||||
)
|
||||
|
||||
|
||||
176
test/unit/core/fast_numeric_conversion.cu
Normal file
176
test/unit/core/fast_numeric_conversion.cu
Normal file
@@ -0,0 +1,176 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Unit tests for conversion operators.
|
||||
*/
|
||||
|
||||
#include "../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace test {
|
||||
namespace core {
|
||||
namespace kernel {
|
||||
|
||||
/// Simple conversion function
|
||||
template <typename Destination, typename Source, int Count>
|
||||
__global__ void convert(
|
||||
cutlass::Array<Destination, Count> *destination,
|
||||
cutlass::Array<Source, Count> const *source) {
|
||||
|
||||
cutlass::FastNumericArrayConverter<Destination, Source, Count> convert;
|
||||
|
||||
*destination = convert(*source);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Destination, typename Source, int Count>
|
||||
void run_test_integer_range_limited() {
|
||||
const int kN = Count;
|
||||
|
||||
dim3 grid(1, 1);
|
||||
dim3 block(1, 1);
|
||||
|
||||
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
|
||||
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
source.host_data()[i] = Source(i % 4);
|
||||
}
|
||||
|
||||
source.sync_device();
|
||||
|
||||
convert<Destination, Source, kN><<< grid, block >>>(
|
||||
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
|
||||
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
|
||||
);
|
||||
|
||||
destination.sync_host();
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Destination, typename Source, int Count>
|
||||
void run_test_integer_range_all() {
|
||||
const int kN = Count;
|
||||
|
||||
dim3 grid(1, 1);
|
||||
dim3 block(1, 1);
|
||||
|
||||
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
|
||||
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
|
||||
|
||||
int const kIntSourceMin = std::numeric_limits<Source>::min();
|
||||
int const kIntSourceMax = std::numeric_limits<Source>::max();
|
||||
int const kIntRange = kIntSourceMax - kIntSourceMin + 1;
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange));
|
||||
|
||||
}
|
||||
|
||||
source.sync_device();
|
||||
|
||||
convert<Destination, Source, kN><<< grid, block >>>(
|
||||
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
|
||||
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
|
||||
);
|
||||
|
||||
destination.sync_host();
|
||||
|
||||
// Verify conversion
|
||||
bool passed = true;
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) {
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(passed) << " FastNumericArrayConverter failed";
|
||||
|
||||
// Print out results for the failed conversion.
|
||||
if (!passed) {
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
std::cout << "source(" << float(source.host_data()[i]) << ") -> "
|
||||
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl;
|
||||
}
|
||||
}
|
||||
std::flush(std::cout);
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace core
|
||||
} // namespace test
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(FastNumericConversion, s32_to_f32) {
|
||||
int const kN = 4;
|
||||
using Source = int;
|
||||
using Destination = float;
|
||||
test::core::kernel::run_test_integer_range_limited<Destination, Source, kN>();
|
||||
}
|
||||
|
||||
TEST(FastNumericConversion, s8_to_f16_array) {
|
||||
int const kN = 256;
|
||||
using Source = int8_t;
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
|
||||
}
|
||||
|
||||
TEST(FastNumericConversion, u8_to_f16_array) {
|
||||
int const kN = 256;
|
||||
using Source = uint8_t;
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
|
||||
}
|
||||
|
||||
TEST(FastNumericConversion, u8_to_bf16_array) {
|
||||
int const kN = 256;
|
||||
using Source = uint8_t;
|
||||
using Destination = cutlass::bfloat16_t;
|
||||
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
|
||||
}
|
||||
|
||||
TEST(FastNumericConversion, s8_to_bf16_array) {
|
||||
int const kN = 256;
|
||||
using Source = int8_t;
|
||||
using Destination = cutlass::bfloat16_t;
|
||||
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
|
||||
}
|
||||
@@ -341,6 +341,21 @@ cutlass_test_unit_add_executable(
|
||||
sm80_gemm_f16_f16_f32_tensor_op_f32.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
# Upcast on Operand A
|
||||
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
|
||||
# Upcast on Operand B
|
||||
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_f64
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_universal.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversal<
|
||||
ElementA,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
4, // Stages
|
||||
8, // AlignmentA
|
||||
16, // AlignmentB
|
||||
cutlass::arch::OpMultiplyAddMixedInputUpcast,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
cutlass::ComplexTransform::kNone
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,97 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_universal.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = uint8_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversal<
|
||||
ElementA,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
4, // Stages
|
||||
8, // AlignmentA
|
||||
16, // AlignmentB
|
||||
cutlass::arch::OpMultiplyAddMixedInputUpcast,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
cutlass::ComplexTransform::kNone
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,97 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_universal.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
TEST(SM80_Device_GemmUniversal_s8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversal<
|
||||
ElementA,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
4, // Stages
|
||||
16, // AlignmentA
|
||||
8, // AlignmentB
|
||||
cutlass::arch::OpMultiplyAddMixedInputUpcast,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
cutlass::ComplexTransform::kNone
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,97 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_universal.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
|
||||
|
||||
using ElementA = uint8_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversal<
|
||||
ElementA,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
4, // Stages
|
||||
16, // AlignmentA
|
||||
8, // AlignmentB
|
||||
cutlass::arch::OpMultiplyAddMixedInputUpcast,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
cutlass::ComplexTransform::kNone
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -103,16 +103,17 @@ struct TestbedUniversal {
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;
|
||||
bool is_unsigned_int = std::numeric_limits<Element>::is_integer && !std::numeric_limits<Element>::is_signed;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
scope_max = is_unsigned_int ? 4 : 2;
|
||||
scope_min = is_unsigned_int ? 0 : -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
scope_max = is_unsigned_int ? 10 : 5;
|
||||
scope_min = is_unsigned_int ? 0 : -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
|
||||
@@ -37,6 +37,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_complex_sm80.cu
|
||||
gemm_sparse_sm80.cu
|
||||
gemm_gaussian_complex_sm80.cu
|
||||
gemm_mixed_input_sm80.cu
|
||||
gemm_sm90.cu
|
||||
gemm_complex_sm90.cu
|
||||
wmma_sm70.cu
|
||||
|
||||
322
test/unit/gemm/warp/gemm_mixed_input_sm80.cu
Normal file
322
test/unit/gemm/warp/gemm_mixed_input_sm80.cu
Normal file
@@ -0,0 +1,322 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Unit tests for thread-level GEMM
|
||||
*/
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/half.h"
|
||||
|
||||
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
|
||||
|
||||
#include "cutlass/core_io.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= F16 * I8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= F16 * I8 + F32: narrow s8 operand B is upcast (OpMultiplyAddMixedInputUpcast)
// to match the f16 operand A. 64x64x64 warp tile, 16x8x16 mma instruction,
// exercised over a 128x128x64 overall shape.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) {

  // Element types: wide A (f16), narrow B (s8), f32 accumulator.
  using ElementA = cutlass::half_t;
  using ElementB = int8_t;
  using ElementC = float;

  // Warp-level tile and mma instruction shapes.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Crosswise (K-major) operand layouts with a 64-element crosswise dimension.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Warp-scoped tensor-op MMA configured for mixed-input upcast.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<128, 128, 64> >().run();
}
|
||||
|
||||
|
||||
// F32 <= F16 * I8 + F32 (upcast on operand B), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16x8x16) {

  // Operand/accumulator element types; s8 B is the narrow (upcast) operand.
  using ElementA = cutlass::half_t;
  using ElementB = int8_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= I8 * F16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= I8 * F16 + F32: narrow s8 operand A is upcast (OpMultiplyAddMixedInputUpcast)
// to match the f16 operand B. 64x64x64 warp tile over a 128x128x64 overall shape.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) {
  // Warp-level tile and mma instruction shapes.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  // Element types: narrow A (s8) is the upcast operand; f32 accumulator.
  using ElementA = int8_t;
  using ElementB = cutlass::half_t;  // fixed: removed stray second semicolon
  using ElementC = float;
  // Crosswise (K-major) operand layouts with a 64-element crosswise dimension.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Warp-scoped tensor-op MMA configured for mixed-input upcast.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}
|
||||
|
||||
|
||||
// F32 <= I8 * F16 + F32 (upcast on operand A), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) {

  // Element types: s8 A is the narrow (upcast) operand; f32 accumulator.
  using ElementA = int8_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= F16 * U8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= F16 * U8 + F32 (upcast on operand B), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) {

  // Element types: u8 B is the narrow (upcast) operand; f32 accumulator.
  using ElementA = cutlass::half_t;
  using ElementB = uint8_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
// F32 <= F16 * U8 + F32 (upcast on operand B), 64x64x64 warp tile over a
// 128x128x64 overall shape.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_16x8x16) {

  // Element types: wide A (f16), narrow B (u8), f32 accumulator.
  using ElementA = cutlass::half_t;
  using ElementB = uint8_t;
  using ElementC = float;

  // Warp-level tile and mma instruction shapes.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Crosswise (K-major) operand layouts with a 64-element crosswise dimension.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Warp-scoped tensor-op MMA configured for mixed-input upcast.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<128, 128, 64> >().run();
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= U8 * F16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= U8 * F16 + F32 (upcast on operand A), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) {

  // Element types: u8 A is the narrow (upcast) operand; f32 accumulator.
  using ElementA = uint8_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
// F32 <= U8 * F16 + F32 (upcast on operand A), 64x64x64 warp tile over a
// 128x128x64 overall shape.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_16x8x16) {

  // Element types: narrow A (u8), wide B (f16), f32 accumulator.
  using ElementA = uint8_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;

  // Warp-level tile and mma instruction shapes.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Crosswise (K-major) operand layouts with a 64-element crosswise dimension.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Warp-scoped tensor-op MMA configured for mixed-input upcast.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<128, 128, 64> >().run();
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= BF16 * U8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= BF16 * U8 + F32 (upcast on operand B), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) {

  // Element types: u8 B is the narrow (upcast) operand; f32 accumulator.
  using ElementA = cutlass::bfloat16_t;
  using ElementB = uint8_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= U8 * BF16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= U8 * BF16 + F32 (upcast on operand A), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) {

  // Element types: u8 A is the narrow (upcast) operand; f32 accumulator.
  using ElementA = uint8_t;
  using ElementB = cutlass::bfloat16_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= BF16 * I8 + F32 (Upcast on Operand B)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= BF16 * I8 + F32 (upcast on operand B), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) {

  // Element types: s8 B is the narrow (upcast) operand; f32 accumulator.
  using ElementA = cutlass::bfloat16_t;
  using ElementB = int8_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// F32 <= I8 * BF16 + F32 (upcast on operand A), single 64x64x64 warp tile.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {

  // Element types: s8 A is the narrow (upcast) operand; f32 accumulator.
  using ElementA = int8_t;
  using ElementB = cutlass::bfloat16_t;
  using ElementC = float;

  // Warp tile equals the overall problem shape here; 16x8x16 mma instruction.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // K-major crosswise layouts (crosswise dimension = 64).
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input warp MMA.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<
      MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >().run();
}
|
||||
|
||||
#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
Reference in New Issue
Block a user