mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 08:50:09 +00:00
Extend DualGemm: support batched mode + decouple B0/B1 layouts (#790)
* Fix MHA kernel Summary: ATT Test Plan: Reviewers: Subscribers: Tasks: Tags: * Extend DualGemm to support batched mode (#5) Following the GemmUniversalMode::kBatched implementation, batched mode is added to the DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched and SplitKSerial are not compatible: Status::kErrorInvalidProblem is returned if both are set. * Decouple LayoutB0 and LayoutB1 in DualGemm The DualGemm template assumed the same layout, LayoutB, for both right operand matrices B0 and B1. This is problematic if the layout of the two matrices is different. In particular, this may be the case when one of the matrices is row-major, while the other is a (column) vector that has to be broadcasted in column-major with zero stride (e.g., as {B1.device_data(), 0}) for the DualGemm implementation to be able to process B0 and B1 simultaneously. In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm code (device, kernel, and mma). Additionally, the batch strides of B0 and B1 are also decoupled to accommodate the column vector B1 case described above. * Remove comment as no longer relevant * Revert Fix MHA kernel --------- Co-authored-by: mikeiovine <mikeiovine@fb.com>
This commit is contained in:
@@ -52,6 +52,7 @@ D2 = element_wise(D0, D1)
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
|
||||
#include "../kernel/dual_gemm.h"
|
||||
#include "../dual_gemm_common.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -68,8 +69,10 @@ template <
|
||||
typename LayoutA_,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Layout type for B0 matrix operand
|
||||
typename LayoutB0_,
|
||||
/// Layout type for B1 matrix operand
|
||||
typename LayoutB1_,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
@@ -119,8 +122,10 @@ class DualGemm {
|
||||
using LayoutA = LayoutA_;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = LayoutB_;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
using LayoutB0 = LayoutB0_;
|
||||
using LayoutB1 = LayoutB1_;
|
||||
using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
|
||||
using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
@@ -151,23 +156,31 @@ class DualGemm {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
|
||||
static_assert(kStages >= 3, "Only multistage is implemented");
|
||||
using Mma = typename cutlass::gemm::threadblock::DefaultMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape, WarpShape,
|
||||
InstructionShape, Stages, Operator>::ThreadblockMma;
|
||||
using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape, WarpShape,
|
||||
InstructionShape, Stages, Operator>::ThreadblockMma;
|
||||
using DualMma = threadblock::DualMmaMultistage<
|
||||
typename Mma::Shape,
|
||||
typename Mma::IteratorA,
|
||||
typename Mma::SmemIteratorA,
|
||||
Mma::kCacheOpA,
|
||||
typename Mma::IteratorB,
|
||||
typename Mma::SmemIteratorB,
|
||||
Mma::kCacheOpB,
|
||||
typename Mma::ElementC,
|
||||
typename Mma::LayoutC,
|
||||
typename Mma::Policy,
|
||||
Mma::kStages,
|
||||
typename Mma0::Shape,
|
||||
typename Mma0::IteratorA,
|
||||
typename Mma0::SmemIteratorA,
|
||||
Mma0::kCacheOpA,
|
||||
typename Mma0::IteratorB,
|
||||
typename Mma0::SmemIteratorB,
|
||||
Mma0::kCacheOpB,
|
||||
typename Mma1::IteratorB,
|
||||
typename Mma1::SmemIteratorB,
|
||||
typename Mma0::ElementC,
|
||||
typename Mma0::LayoutC,
|
||||
typename Mma0::Policy,
|
||||
typename Mma1::Policy,
|
||||
Mma0::kStages,
|
||||
SharedMemoryClearOption::kNone
|
||||
>;
|
||||
|
||||
@@ -176,11 +189,11 @@ class DualGemm {
|
||||
/// Define the epilogue
|
||||
using Epilogue0 =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0,
|
||||
ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0,
|
||||
EpilogueOutputOp0::kCount>::Epilogue;
|
||||
using Epilogue1 =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1,
|
||||
ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
@@ -197,12 +210,13 @@ class DualGemm {
|
||||
// Data members
|
||||
//
|
||||
|
||||
DualGemmMode mode;
|
||||
GemmCoord problem_size;
|
||||
TensorRef<ElementA const, LayoutA> ref_A0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B0;
|
||||
TensorRef<ElementB const, LayoutB0> ref_B0;
|
||||
TensorRef<ElementC const, LayoutC> ref_C0;
|
||||
TensorRef<ElementC, LayoutC> ref_D0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B1;
|
||||
TensorRef<ElementB const, LayoutB1> ref_B1;
|
||||
TensorRef<ElementC const, LayoutC> ref_C1;
|
||||
TensorRef<ElementC, LayoutC> ref_D1;
|
||||
TensorRef<ElementC, LayoutC> ref_D2;
|
||||
@@ -211,6 +225,13 @@ class DualGemm {
|
||||
typename EpilogueOutputOp2::Params epilogue2;
|
||||
int split_k_slices;
|
||||
|
||||
int batch_count;
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@@ -224,12 +245,13 @@ class DualGemm {
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
DualGemmMode mode,
|
||||
GemmCoord problem_size_,
|
||||
TensorRef<ElementA const, LayoutA> ref_A0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B0_,
|
||||
TensorRef<ElementB const, LayoutB0> ref_B0_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C0_,
|
||||
TensorRef<ElementC, LayoutC> ref_D0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B1_,
|
||||
TensorRef<ElementB const, LayoutB1> ref_B1_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C1_,
|
||||
TensorRef<ElementC, LayoutC> ref_D1_,
|
||||
TensorRef<ElementC, LayoutC> ref_D2_,
|
||||
@@ -239,8 +261,15 @@ class DualGemm {
|
||||
typename EpilogueOutputOp1::Params(),
|
||||
typename EpilogueOutputOp2::Params epilogue2_ =
|
||||
typename EpilogueOutputOp2::Params(),
|
||||
int split_k_slices_ = 1
|
||||
int split_k_slices_ = 1,
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A = 0,
|
||||
int64_t batch_stride_B0 = 0,
|
||||
int64_t batch_stride_B1 = 0,
|
||||
int64_t batch_stride_C = 0,
|
||||
int64_t batch_stride_D = 0
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size_),
|
||||
ref_A0(ref_A0_),
|
||||
ref_B0(ref_B0_),
|
||||
@@ -253,7 +282,13 @@ class DualGemm {
|
||||
epilogue0(epilogue0_),
|
||||
epilogue1(epilogue1_),
|
||||
epilogue2(epilogue2_),
|
||||
split_k_slices(split_k_slices_) {
|
||||
split_k_slices(split_k_slices_),
|
||||
batch_count(batch_count),
|
||||
batch_stride_A(batch_stride_A),
|
||||
batch_stride_B0(batch_stride_B0),
|
||||
batch_stride_B1(batch_stride_B1),
|
||||
batch_stride_C(batch_stride_C),
|
||||
batch_stride_D(batch_stride_D) {
|
||||
|
||||
}
|
||||
};
|
||||
@@ -271,6 +306,9 @@ public:
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
if (!kSplitKSerial && args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
@@ -304,17 +342,15 @@ public:
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
||||
}
|
||||
@@ -331,7 +367,7 @@ public:
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.split_k_slices);
|
||||
args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial) {
|
||||
if (args.split_k_slices > 1) {
|
||||
@@ -357,6 +393,7 @@ public:
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename DualGemmKernel::Params{
|
||||
args.mode,
|
||||
args.problem_size,
|
||||
grid_shape,
|
||||
args.ref_A0.non_const_ref(),
|
||||
@@ -371,6 +408,11 @@ public:
|
||||
args.epilogue1,
|
||||
args.epilogue2,
|
||||
reinterpret_cast<int *>(workspace),
|
||||
args.batch_stride_A,
|
||||
args.batch_stride_B0,
|
||||
args.batch_stride_B1,
|
||||
args.batch_stride_C,
|
||||
args.batch_stride_D,
|
||||
};
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@@ -43,8 +43,6 @@ D2 = element_wise(D0, D1)
|
||||
D0 and D1 will be optionally stored in gmem (`kStoreD0` / `kStoreD1`)
|
||||
*/
|
||||
|
||||
// #define IS_PROFILING
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@@ -66,10 +64,12 @@ D2 = element_wise(D0, D1)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192);
|
||||
cutlass::gemm::GemmCoord batch_problem_size(321, 256, 512);
|
||||
|
||||
constexpr int kStages = 3;
|
||||
constexpr bool kSplitKSerial = false;
|
||||
constexpr bool kUseBias = true;
|
||||
constexpr int kBatchCount = 37;
|
||||
|
||||
|
||||
#if 0
|
||||
@@ -165,7 +165,16 @@ bool run_nonfused_gemm_f16_sm80() {
|
||||
NonFusedDualGemmRun<Gemm0, Gemm1> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
|
||||
|
||||
bool pass = nonFusedGemm.run(
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
true /* is_profiling */
|
||||
);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
@@ -215,6 +224,7 @@ bool run_fused_gemm_f16_sm80_shmem() {
|
||||
cutlass::layout::RowMajor,
|
||||
ElementOperandB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
@@ -236,24 +246,212 @@ bool run_fused_gemm_f16_sm80_shmem() {
|
||||
DualFusedGemmRun<DualGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
bool passed = fusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool run_batched_fused_gemm_f16_sm80_shmem() {
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
// Optionally, we might not need intermediate GEMM outputs
|
||||
constexpr bool kStoreD0 = true;
|
||||
constexpr bool kStoreD1 = true;
|
||||
|
||||
using DualGemm = cutlass::gemm::device::DualGemm<
|
||||
ElementOperandA,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementOperandB,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp2,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
kStages,
|
||||
kStoreD0,
|
||||
kStoreD1,
|
||||
kSplitKSerial
|
||||
>;
|
||||
|
||||
DualFusedGemmRun<DualGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Batched Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
batch_problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
kBatchCount,
|
||||
false, /* broadcast_b1 */
|
||||
false /* is_profiling */
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool run_broadcast_fused_gemm_f16_sm80_shmem() {
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
// Optionally, we might not need intermediate GEMM outputs
|
||||
constexpr bool kStoreD0 = true;
|
||||
constexpr bool kStoreD1 = true;
|
||||
|
||||
using DualGemm = cutlass::gemm::device::DualGemm<
|
||||
ElementOperandA,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementOperandB,
|
||||
// different LayoutB0 and B1
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp2,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
kStages,
|
||||
kStoreD0,
|
||||
kStoreD1,
|
||||
kSplitKSerial
|
||||
>;
|
||||
|
||||
DualFusedGemmRun<DualGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Broadcast Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
1, /* batch_count */
|
||||
true, /* broadcast_b1 */
|
||||
true /* is_profiling */
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool run_batched_broadcast_fused_gemm_f16_sm80_shmem() {
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
// Optionally, we might not need intermediate GEMM outputs
|
||||
constexpr bool kStoreD0 = true;
|
||||
constexpr bool kStoreD1 = true;
|
||||
|
||||
using DualGemm = cutlass::gemm::device::DualGemm<
|
||||
ElementOperandA,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementOperandB,
|
||||
// different LayoutB0 and B1
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp2,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
kStages,
|
||||
kStoreD0,
|
||||
kStoreD1,
|
||||
kSplitKSerial
|
||||
>;
|
||||
|
||||
DualFusedGemmRun<DualGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Batch Broadcast Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
batch_problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
kBatchCount,
|
||||
true, /* broadcast_b1 */
|
||||
false /* is_profiling */
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_f16_sm80,
|
||||
&run_fused_gemm_f16_sm80_shmem
|
||||
&run_fused_gemm_f16_sm80_shmem,
|
||||
&run_batched_fused_gemm_f16_sm80_shmem,
|
||||
&run_broadcast_fused_gemm_f16_sm80_shmem,
|
||||
&run_batched_broadcast_fused_gemm_f16_sm80_shmem
|
||||
};
|
||||
|
||||
std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" + std::to_string(kSplitKSerial);
|
||||
std::string test_name = (
|
||||
"dual-gemm f16 bias=" +
|
||||
std::to_string(kUseBias) +
|
||||
" split_k_serial=" +
|
||||
std::to_string(kSplitKSerial) +
|
||||
" batch_count=" +
|
||||
std::to_string(kBatchCount)
|
||||
);
|
||||
|
||||
return testRun(80, funcs, test_name);
|
||||
}
|
||||
|
||||
|
||||
52
examples/45_dual_gemm/dual_gemm_common.h
Normal file
52
examples/45_dual_gemm/dual_gemm_common.h
Normal file
@@ -0,0 +1,52 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines common types used for all DualGemm operators.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class DualGemmMode {
|
||||
kGemm,
|
||||
kBatched,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -33,6 +33,7 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
@@ -44,6 +45,10 @@
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
#include "dual_gemm_common.h"
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
@@ -205,6 +210,7 @@ struct NonFusedDualGemmRun
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool is_profiling = true,
|
||||
bool relu = false,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
@@ -335,40 +341,41 @@ struct NonFusedDualGemmRun
|
||||
status = gemm_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
#ifdef IS_PROFILING
|
||||
return true;
|
||||
#endif
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
cudaEvent_t start, stop1, stop2;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop1);
|
||||
cudaEventCreate(&stop2);
|
||||
|
||||
cudaEventRecord(start);
|
||||
if (is_profiling) {
|
||||
//
|
||||
// Profile the GEMM
|
||||
//
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
cudaEvent_t start, stop1, stop2;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop1);
|
||||
cudaEventCreate(&stop2);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop2);
|
||||
cudaDeviceSynchronize();
|
||||
float gemm0Time, gemm1Time, totalTime;
|
||||
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
||||
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop2);
|
||||
cudaDeviceSynchronize();
|
||||
float gemm0Time, gemm1Time, totalTime;
|
||||
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
||||
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
@@ -543,6 +550,9 @@ struct DualFusedGemmRun
|
||||
ElementCompute beta0 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(1),
|
||||
int batch_count = 1,
|
||||
bool broadcast_b1 = false,
|
||||
bool is_profiling = true,
|
||||
bool relu = false,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
@@ -553,55 +563,91 @@ struct DualFusedGemmRun
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementA,
|
||||
typename DualGemm::LayoutA> tensor_A0(problem_size.mk());
|
||||
typename DualGemm::LayoutA> tensor_A0(
|
||||
std::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementB,
|
||||
typename DualGemm::LayoutB> tensor_B0(problem_size.kn());
|
||||
typename DualGemm::LayoutB0> tensor_B0(
|
||||
std::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_C0(problem_size.mn());
|
||||
typename DualGemm::LayoutC> tensor_C0(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()});
|
||||
typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D0(problem_size.mn());
|
||||
typename DualGemm::LayoutC> tensor_D0(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D0(problem_size.mn());
|
||||
typename DualGemm::LayoutC> reference_D0(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementB,
|
||||
typename DualGemm::LayoutB> tensor_B1(problem_size.kn());
|
||||
typename DualGemm::LayoutB1> tensor_B1(
|
||||
std::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
|
||||
if (broadcast_b1) {
|
||||
tensor_B1.resize({problem_size.k(), batch_count});
|
||||
}
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_C1(problem_size.mn());
|
||||
typename DualGemm::LayoutC> tensor_C1(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()});
|
||||
typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D1(problem_size.mn());
|
||||
typename DualGemm::LayoutC> tensor_D1(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D2(problem_size.mn());
|
||||
typename DualGemm::LayoutC> tensor_D2(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D1(problem_size.mn());
|
||||
typename DualGemm::LayoutC> reference_D1(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D2(problem_size.mn());
|
||||
typename DualGemm::LayoutC> reference_D2(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
|
||||
@@ -638,6 +684,20 @@ struct DualFusedGemmRun
|
||||
reference_D1.sync_device();
|
||||
reference_D2.sync_device();
|
||||
|
||||
//
|
||||
// Batch strides (irrelevant when batch_count == 1)
|
||||
//
|
||||
|
||||
int64_t batch_stride_A = problem_size.m() * problem_size.k();
|
||||
int64_t batch_stride_B0 = problem_size.k() * problem_size.n();
|
||||
int64_t batch_stride_B1 = problem_size.k() * problem_size.n();
|
||||
if (broadcast_b1) {
|
||||
// B1 is a (column) vector
|
||||
batch_stride_B1 = problem_size.k();
|
||||
}
|
||||
int64_t batch_stride_Bias = problem_size.n();
|
||||
int64_t batch_stride_D = problem_size.m() * problem_size.n();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
@@ -652,21 +712,36 @@ struct DualFusedGemmRun
|
||||
ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
|
||||
}
|
||||
typename DualGemm::Arguments arguments{
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::DualGemmMode::kBatched :
|
||||
cutlass::gemm::DualGemmMode::kGemm),
|
||||
problem_size,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
ref_B0,
|
||||
DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
|
||||
tensor_B1.device_ref(),
|
||||
(broadcast_b1 ?
|
||||
typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
|
||||
tensor_B1.device_ref()),
|
||||
ref_B1,
|
||||
DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
|
||||
tensor_D2.device_ref(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
{},
|
||||
split_k_slices
|
||||
split_k_slices,
|
||||
batch_count,
|
||||
batch_stride_A,
|
||||
batch_stride_B0,
|
||||
batch_stride_B1,
|
||||
batch_stride_Bias,
|
||||
batch_stride_D,
|
||||
};
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
DualGemm b2b_gemm_op;
|
||||
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));
|
||||
@@ -684,31 +759,29 @@ struct DualFusedGemmRun
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
#ifdef IS_PROFILING
|
||||
return true;
|
||||
#endif
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
if (is_profiling) {
|
||||
//
|
||||
// Profile the GEMM
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
tensor_D2.sync_host();
|
||||
@@ -717,45 +790,81 @@ struct DualFusedGemmRun
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB0,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
ElementAccumulator
|
||||
>;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename DualGemm::Operator>
|
||||
reference_gemm_1;
|
||||
GemmUniversal0 reference_gemm0;
|
||||
|
||||
reference_gemm_0(
|
||||
typename GemmUniversal0::Arguments args0 {
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::GemmUniversalMode::kBatched :
|
||||
cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
problem_size,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
batch_count,
|
||||
{alpha0, beta0},
|
||||
tensor_A0.device_data(),
|
||||
tensor_B0.device_data(),
|
||||
tensor_Bias0.device_data(),
|
||||
reference_D0.device_data(),
|
||||
batch_stride_A,
|
||||
batch_stride_B0,
|
||||
batch_stride_Bias,
|
||||
batch_stride_D,
|
||||
tensor_A0.stride(0),
|
||||
tensor_B0.stride(0),
|
||||
0, // zero stride for the bias vector
|
||||
reference_D0.stride(0),
|
||||
};
|
||||
|
||||
status = reference_gemm0.can_implement(args0);
|
||||
CUTLASS_CHECK(status);
|
||||
status = reference_gemm0(args0);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB1,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
ElementAccumulator
|
||||
>;
|
||||
|
||||
GemmUniversal1 reference_gemm1;
|
||||
|
||||
typename GemmUniversal1::Arguments args1 {
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::GemmUniversalMode::kBatched :
|
||||
cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
problem_size,
|
||||
batch_count,
|
||||
{alpha1, beta1},
|
||||
tensor_A0.device_data(),
|
||||
tensor_B1.device_data(),
|
||||
tensor_Bias1.device_data(),
|
||||
reference_D1.device_data(),
|
||||
batch_stride_A,
|
||||
batch_stride_B1,
|
||||
batch_stride_Bias,
|
||||
batch_stride_D,
|
||||
tensor_A0.stride(0),
|
||||
(broadcast_b1 ? 0 : tensor_B1.stride(0)),
|
||||
0, // zero stride for the bias vector
|
||||
reference_D1.stride(0),
|
||||
};
|
||||
|
||||
status = reference_gemm1.can_implement(args1);
|
||||
CUTLASS_CHECK(status);
|
||||
status = reference_gemm1(args1);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size,
|
||||
alpha1,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
TensorEpilogueForEach<EpilogueOutputOp2>(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view());
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
@@ -793,7 +902,6 @@ struct DualFusedGemmRun
|
||||
bool passed = passed_out0 && passed_out1 && passed_out2;
|
||||
if (!passed)
|
||||
{
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_DualGemm_device_fused.txt";
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
|
||||
#include "../threadblock/dual_mma_multistage.h"
|
||||
#include "../threadblock/dual_epilogue.h"
|
||||
#include "../dual_gemm_common.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -92,6 +93,10 @@ struct DualGemm {
|
||||
true // IterationsUnroll
|
||||
>;
|
||||
|
||||
using ElementA = typename DualMma::IteratorA::Element;
|
||||
using ElementB = typename DualMma::IteratorB0::Element;
|
||||
using ElementC = typename DualEpilogue::OutputTileIterator::Element;
|
||||
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1),
|
||||
"Split-K serial requires buffers for D0/D1 for reduction");
|
||||
@@ -102,6 +107,7 @@ struct DualGemm {
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
DualGemmMode mode;
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
@@ -109,8 +115,8 @@ struct DualGemm {
|
||||
// Mma0
|
||||
typename DualMma::IteratorA::Params params_A0;
|
||||
typename DualMma::IteratorA::TensorRef ref_A0;
|
||||
typename DualMma::IteratorB::Params params_B0;
|
||||
typename DualMma::IteratorB::TensorRef ref_B0;
|
||||
typename DualMma::IteratorB0::Params params_B0;
|
||||
typename DualMma::IteratorB0::TensorRef ref_B0;
|
||||
typename Epilogue0::OutputTileIterator::Params params_C0;
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_C0;
|
||||
typename Epilogue0::OutputTileIterator::Params params_D0;
|
||||
@@ -118,8 +124,8 @@ struct DualGemm {
|
||||
typename OutputOp0::Params output_op_0;
|
||||
|
||||
// Mma1
|
||||
typename DualMma::IteratorB::Params params_B1;
|
||||
typename DualMma::IteratorB::TensorRef ref_B1;
|
||||
typename DualMma::IteratorB1::Params params_B1;
|
||||
typename DualMma::IteratorB1::TensorRef ref_B1;
|
||||
typename Epilogue1::OutputTileIterator::Params params_C1;
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_C1;
|
||||
typename Epilogue1::OutputTileIterator::Params params_D1;
|
||||
@@ -133,6 +139,12 @@ struct DualGemm {
|
||||
int *semaphore;
|
||||
int gemm_k_size;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@@ -142,15 +154,16 @@ struct DualGemm {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
DualGemmMode mode,
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
// Mma0: D0 = A @ B0 + C0
|
||||
typename DualMma::IteratorA::TensorRef ref_A0,
|
||||
typename DualMma::IteratorB::TensorRef ref_B0,
|
||||
typename DualMma::IteratorB0::TensorRef ref_B0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
|
||||
// Mma1: D1 = A @ B1 + C1
|
||||
typename DualMma::IteratorB::TensorRef ref_B1,
|
||||
typename DualMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
|
||||
|
||||
@@ -158,8 +171,14 @@ struct DualGemm {
|
||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||
typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(),
|
||||
int *workspace = nullptr
|
||||
int *workspace = nullptr,
|
||||
int64_t batch_stride_A = 1,
|
||||
int64_t batch_stride_B0 = 1,
|
||||
int64_t batch_stride_B1 = 1,
|
||||
int64_t batch_stride_C = 1,
|
||||
int64_t batch_stride_D = 1
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
@@ -183,13 +202,18 @@ struct DualGemm {
|
||||
ref_D2(ref_D2),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1),
|
||||
output_op_2(output_op_2) {
|
||||
output_op_2(output_op_2),
|
||||
batch_stride_A(batch_stride_A),
|
||||
batch_stride_B0(batch_stride_B0),
|
||||
batch_stride_B1(batch_stride_B1),
|
||||
batch_stride_C(batch_stride_C),
|
||||
batch_stride_D(batch_stride_D) {
|
||||
|
||||
int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
|
||||
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
gemm_k_size = gemm_k_iterations * DualMma::Shape::kK;
|
||||
|
||||
semaphore = workspace;
|
||||
semaphore = workspace;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -210,16 +234,16 @@ struct DualGemm {
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
typename DualMma::IteratorA::TensorRef ref_A0,
|
||||
typename DualMma::IteratorB::TensorRef ref_B0,
|
||||
typename DualMma::IteratorB0::TensorRef ref_B0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
|
||||
typename DualMma::IteratorB::TensorRef ref_B1,
|
||||
typename DualMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D2) {
|
||||
|
||||
static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
|
||||
@@ -273,52 +297,66 @@ struct DualGemm {
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
|
||||
ElementB *ptr_B0 = static_cast<ElementB *>(params.ref_B0.data());
|
||||
ElementB *ptr_B1 = static_cast<ElementB *>(params.ref_B1.data());
|
||||
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == DualGemmMode::kGemm) {
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
}
|
||||
else if (params.mode == DualGemmMode::kBatched) {
|
||||
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
||||
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A0{
|
||||
threadblock_tile_offset.m() * DualMma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B0{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
offset_k,
|
||||
threadblock_tile_offset.n() * DualMma::Shape::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B1{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
offset_k,
|
||||
threadblock_tile_offset.n() * DualMma::Shape::kN
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k =
|
||||
(params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ?
|
||||
params.problem_size.k() :
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename DualMma::IteratorA iterator_A0(
|
||||
params.params_A0,
|
||||
params.ref_A0.data(),
|
||||
ptr_A0,
|
||||
{params.problem_size.m(), problem_size_k},
|
||||
thread_idx,
|
||||
tb_offset_A0);
|
||||
|
||||
typename DualMma::IteratorB iterator_B0(
|
||||
typename DualMma::IteratorB0 iterator_B0(
|
||||
params.params_B0,
|
||||
params.ref_B0.data(),
|
||||
ptr_B0,
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B0);
|
||||
|
||||
typename DualMma::IteratorB iterator_B1(
|
||||
typename DualMma::IteratorB1 iterator_B1(
|
||||
params.params_B1,
|
||||
params.ref_B1.data(),
|
||||
ptr_B1,
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B1);
|
||||
@@ -340,6 +378,9 @@ struct DualGemm {
|
||||
accum0.clear();
|
||||
accum1.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
|
||||
|
||||
DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
if (!kSplitKSerial || gemm_k_iterations > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
@@ -372,31 +413,46 @@ struct DualGemm {
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
||||
ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
|
||||
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
||||
ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
if (params.mode == DualGemmMode::kGemm) {
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
}
|
||||
else if (params.mode == DualGemmMode::kBatched) {
|
||||
ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C;
|
||||
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C;
|
||||
ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D;
|
||||
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D;
|
||||
ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D;
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue0::OutputTileIterator iterator_C0(
|
||||
params.params_C0,
|
||||
params.ref_C0.data(),
|
||||
ptr_C0,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue1::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
params.ref_C1.data(),
|
||||
ptr_C1,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
@@ -405,21 +461,21 @@ struct DualGemm {
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue0::OutputTileIterator iterator_D0(
|
||||
params.params_D0,
|
||||
params.ref_D0.data(),
|
||||
ptr_D0,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue1::OutputTileIterator iterator_D1(
|
||||
params.params_D1,
|
||||
params.ref_D1.data(),
|
||||
ptr_D1,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue1::OutputTileIterator iterator_D2(
|
||||
params.params_D2,
|
||||
params.ref_D2.data(),
|
||||
ptr_D2,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
|
||||
@@ -58,7 +58,9 @@ template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
typename Policy0_,
|
||||
/// B1-specific version of the policy (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
@@ -69,18 +71,20 @@ class DualMmaBase {
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
using Policy0 = Policy0_;
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
using WarpGemm = typename Policy0::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
|
||||
@@ -89,16 +93,17 @@ class DualMmaBase {
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
using TensorRefA = TensorRef<typename Operator0::ElementA, typename Operator0::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
using TensorRefB0 = TensorRef<typename Operator0::ElementB, typename Operator0::LayoutB>;
|
||||
using TensorRefB1 = TensorRef<typename Operator1::ElementB, typename Operator1::LayoutB>;
|
||||
|
||||
static_assert(kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
@@ -119,14 +124,17 @@ class DualMmaBase {
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
using ShapeA = MatrixShape<Shape::kM + Policy0::SmemPaddingA::kRow,
|
||||
Shape::kK * kStages +
|
||||
Policy::SmemPaddingA::kColumn>;
|
||||
Policy0::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB =
|
||||
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
using ShapeB0 =
|
||||
MatrixShape<Shape::kK * kStages + Policy0::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy0::SmemPaddingB::kColumn>;
|
||||
using ShapeB1 =
|
||||
MatrixShape<Shape::kK * kStages + Policy1::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy1::SmemPaddingB::kColumn>;
|
||||
|
||||
public:
|
||||
//
|
||||
@@ -134,11 +142,11 @@ class DualMmaBase {
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
AlignedBuffer<typename Operator0::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B0;
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B1;
|
||||
AlignedBuffer<typename Operator0::ElementB, ShapeB0::kCount> operand_B0;
|
||||
AlignedBuffer<typename Operator1::ElementB, ShapeB1::kCount> operand_B1;
|
||||
|
||||
public:
|
||||
|
||||
@@ -148,14 +156,20 @@ class DualMmaBase {
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
static typename Operator0::LayoutA LayoutA() {
|
||||
return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
static typename Operator0::LayoutB LayoutB0() {
|
||||
return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator1::LayoutB LayoutB1() {
|
||||
return Operator1::LayoutB::packed({ShapeB1::kRow, ShapeB1::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
@@ -166,12 +180,12 @@ class DualMmaBase {
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B0_ref() {
|
||||
return TensorRefB{operand_B0.data(), LayoutB()};
|
||||
TensorRefB0 operand_B0_ref() {
|
||||
return TensorRefB0{operand_B0.data(), LayoutB0()};
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B1_ref() {
|
||||
return TensorRefB{operand_B1.data(), LayoutB()};
|
||||
TensorRefB1 operand_B1_ref() {
|
||||
return TensorRefB1{operand_B1.data(), LayoutB1()};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -182,11 +196,11 @@ class DualMmaBase {
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
typename Operator0::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B0_;
|
||||
typename Operator::IteratorB warp_tile_iterator_B1_;
|
||||
typename Operator0::IteratorB warp_tile_iterator_B0_;
|
||||
typename Operator1::IteratorB warp_tile_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
|
||||
@@ -67,21 +67,30 @@ template <
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
/// Iterates over tiles of B0 operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B0 operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Iterates over tiles of B1 operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B1 operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
typename Policy0_,
|
||||
/// B1-specific version of the policy (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
@@ -89,25 +98,29 @@ template <
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class DualMmaMultistage :
|
||||
public DualMmaBase<Shape_, Policy_, Stages> {
|
||||
public DualMmaBase<Shape_, Policy0_, Policy1_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DualMmaBase<Shape_, Policy_, Stages>;
|
||||
using Base = DualMmaBase<Shape_, Policy0_, Policy1_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Iterates over tiles of B0 operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Iterates over tiles of B1 operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
using Policy0 = Policy0_;
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
@@ -117,19 +130,21 @@ public:
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
using FragmentC = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
static ComplexTransform const kTransformA = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
@@ -140,7 +155,7 @@ public:
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
IteratorB0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
@@ -156,10 +171,12 @@ public:
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
using WarpLoadedFragmentA = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
@@ -171,8 +188,8 @@ public:
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B0_;
|
||||
SmemIteratorB smem_iterator_B1_;
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
@@ -215,7 +232,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1,
|
||||
void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1,
|
||||
int group_start_A = 0, int group_start_B = 0) {
|
||||
iterator_A.set_iteration_index(group_start_A *
|
||||
IteratorA::kAccessesPerVector);
|
||||
@@ -253,9 +270,9 @@ public:
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
IteratorB0::kAccessesPerVector);
|
||||
iterator_B1.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
IteratorB1::kAccessesPerVector);
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B);
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B);
|
||||
|
||||
@@ -263,16 +280,16 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B0.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
@@ -292,16 +309,16 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B1.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
@@ -330,8 +347,8 @@ public:
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B0,
|
||||
IteratorB iterator_B1,
|
||||
IteratorB0 iterator_B0,
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum0,
|
||||
FragmentC const &src_accum1
|
||||
@@ -386,16 +403,16 @@ public:
|
||||
// Async Copy for operand B0
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
|
||||
@@ -408,16 +425,16 @@ public:
|
||||
// Async Copy for operand B1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
|
||||
@@ -473,35 +490,35 @@ public:
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
typename IteratorB::AccessType zero_B;
|
||||
typename IteratorB0::AccessType zero_B;
|
||||
zero_B.clear();
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_);
|
||||
SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_);
|
||||
last_smem_iterator_B0.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
// Async Copy for operand B0
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
last_smem_iterator_B0.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B0;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_);
|
||||
SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_);
|
||||
last_smem_iterator_B1.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
// Async Copy for operand B1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
last_smem_iterator_B1.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
@@ -517,13 +534,14 @@ public:
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A[2];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B0[2];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B1[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B0[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B1[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator warp_mma;
|
||||
Operator0 warp_mma0;
|
||||
Operator1 warp_mma1;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
@@ -544,10 +562,10 @@ public:
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);
|
||||
warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
|
||||
warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
|
||||
// accumulator and this temporary accumulator is added to the final
|
||||
@@ -556,9 +574,9 @@ public:
|
||||
|
||||
FragmentC tmp_accum0, tmp_accum1;
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
if (platform::is_same<typename Operator0::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
|| platform::is_same<typename Operator0::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
|
||||
tmp_accum0.clear();
|
||||
@@ -597,28 +615,28 @@ public:
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
}
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
if (platform::is_same<typename Operator0::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
|| platform::is_same<typename Operator0::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
|
||||
warp_mma(
|
||||
warp_mma0(
|
||||
tmp_accum0,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
tmp_accum0
|
||||
);
|
||||
warp_mma(
|
||||
warp_mma1(
|
||||
tmp_accum1,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
@@ -632,13 +650,13 @@ public:
|
||||
tmp_accum1.clear();
|
||||
}
|
||||
} else {
|
||||
warp_mma(
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
warp_mma(
|
||||
warp_mma1(
|
||||
accum1,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
@@ -696,14 +714,14 @@ public:
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
@@ -720,22 +738,22 @@ public:
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
if (platform::is_same<typename Operator0::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
|| platform::is_same<typename Operator0::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
accum0 = plus_accum(accum0, tmp_accum0);
|
||||
accum1 = plus_accum(accum1, tmp_accum1);
|
||||
|
||||
Reference in New Issue
Block a user