diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
index 306e8cf47..e7ef56b96 100644
--- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
+++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
@@ -341,7 +341,7 @@ struct B2bGemm {
     OutputOp0 output_op_0(params.output_op_0);
 
     // Construct thread-scoped matrix multiply
-    B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
+    B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
 
     typename B2bMma::FragmentC0 src_accum;
     typename B2bMma::FragmentC1 accumulators;
diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h
index 8104f6385..c9f51f4e9 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h
@@ -267,7 +267,9 @@ public:
     ///< ID of warp
     int warp_idx,
     ///< ID of each thread within a warp
-    int lane_idx
+    int lane_idx,
+    ///< GEMM0 N is used for accumulator extent
+    int problem_size_0_n
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
@@ -639,7 +641,6 @@ public:
     }
 
 
-
     // 2nd Gemm
 
     /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
@@ -657,12 +658,11 @@ public:
     tb_frag_A1_bias.clear();
     iterator_A1_bias.load(tb_frag_A1_bias);
     ++iterator_A1_bias;
-
-
+
     //
     // Prologue
     //
-    int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
+    int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
 
     // Issue several complete stages
     CUTLASS_PRAGMA_UNROLL
@@ -750,9 +750,9 @@ public:
     // Mainloop
     //
 
+    gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1);
     CUTLASS_PRAGMA_UNROLL
-    for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
-        gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
+    for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
       //
       // Loop over GEMM K dimension
       //
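Note on the two gemm_k_iterations_1 changes above: both sites switch from truncating to round-up division, so that when FragmentIteratorA1::Policy::kIterations is not an exact multiple of Base::kWarpGemmIterations1 (which happens once GEMM0's N is no longer tile-divisible), the trailing partial warp-level iteration is still issued instead of being dropped. A minimal standalone sketch of the arithmetic, using a hypothetical ceil_div helper that is not part of this patch:

    // Round-up integer division, equivalent to the inline expression in the patch.
    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // Exact multiple: truncating and round-up division agree.
    static_assert(ceil_div(8, 4) == 8 / 4, "both yield 2");

    // Partial tail: truncating division would issue one iteration too few.
    static_assert(9 / 4 == 2 && ceil_div(9, 4) == 3, "round-up keeps the tail iteration");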
diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
index c28f4e49c..2b0eca00c 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
@@ -276,13 +276,15 @@ public:
     ///< ID of warp
     int warp_idx,
     ///< ID of each thread within a warp
-    int lane_idx
+    int lane_idx,
+    ///< GEMM0 N is used for accumulator extent
+    int problem_size_0_n
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
     smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
     smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
-    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
+    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx),
     smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
 
     // Compute warp location within threadblock tile by mapping the warp_id to
diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h
index 4e39fda5b..d43ba46bc 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h
@@ -228,7 +228,8 @@ public:
     typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
     int thread_idx,                                     ///< ID within the threadblock
     int warp_idx,                                       ///< ID of warp
-    int lane_idx                                        ///< ID of each thread within a warp
+    int lane_idx,                                       ///< ID of each thread within a warp
+    int problem_size_0_n                                ///< GEMM0 N is used for accumulator extent
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
index b548c8576..c466b0cb7 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
@@ -236,13 +236,14 @@ public:
     typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
     int thread_idx,                                     ///< ID within the threadblock
     int warp_idx,                                       ///< ID of warp
-    int lane_idx                                        ///< ID of each thread within a warp
+    int lane_idx,                                       ///< ID of each thread within a warp
+    int problem_size_0_n                                ///< GEMM0 N is used for accumulator extent
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
     smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
     smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
-    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
+    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx),
     smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
 
     // Compute warp location within threadblock tile by mapping the warp_id to
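All four B2bMma variants now take a trailing problem_size_0_n constructor argument, and the two shared-memory-accumulator variants forward it as the column extent of warp_tile_iterator_A1_, so loads beyond GEMM0's N can be handled as a residual tile. Callers constructing a B2bMma directly must pass the extra argument, as the b2b_gemm.h hunk above does. A hedged sketch of the wiring (the wrapper function is hypothetical; the constructor call matches the hunks):

    #include "cutlass/cutlass.h"

    // Hypothetical helper showing the new argument threading. B2bMma, its
    // shared storage, and Params stand for the types in the hunks above.
    template <typename B2bMma, typename SharedStorage, typename Params>
    CUTLASS_DEVICE
    void construct_b2b_mma(SharedStorage &shared_storage, Params const &params,
                           int thread_idx, int warp_idx, int lane_idx) {
      // GEMM0's runtime N bounds the valid columns of the staged accumulator;
      // internally it becomes the extent {Base::WarpGemm1::kM, problem_size_0_n}.
      B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx,
                    params.problem_size_0.n());
      (void)b2bMma;  // constructed for illustration only
    }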
diff --git a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h
index ea1a258fb..0ea41c712 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h
@@ -43,7 +43,7 @@
 #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
 #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
 #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
-#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
 
 #include "threadblock/b2b_mma_pipelined_smem_accumulator.h"
 #include "threadblock/b2b_mma_multistage_smem_accumulator.h"
@@ -158,11 +158,11 @@ struct DefaultB2bMma<
 
   /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
-  using WarpIteratorA1 =
-      cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
-      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 =
+      cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
   // Define the threadblock-scoped pipelined matrix multiply
   using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
@@ -303,11 +303,11 @@ struct DefaultB2bMma<
 
   /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
-  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
-      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
   // Define the threadblock-scoped pipelined matrix multiply
   using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
@@ -436,11 +436,11 @@ struct DefaultB2bMma<
 
   /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
-  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
-      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
   // Define the threadblock-scoped pipelined matrix multiply
   using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
@@ -574,11 +574,11 @@ struct DefaultB2bMma<
 
   /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
-  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
-      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
   // Define the threadblock-scoped multistage matrix multiply
diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h
new file mode 100644
index 000000000..fce3fa592
--- /dev/null
+++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h
@@ -0,0 +1,362 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/array.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/matrix_shape.h"
+
+#include "cutlass/arch/memory_sm75.h"
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/layout/matrix.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/pitch_linear.h"
+#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
+
+#include "cutlass/platform/platform.h"
+#include "cutlass/fast_math.h"
+
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace warp {
+
+/// Tile access iterator
+/// Each access iteration in the tile is
+/// used as a multiplicand for one
+/// warp-level matrix multiplication
+template <
+  /// Size of the tile (concept: MatrixShape)
+  typename Shape_,
+  /// Operand identity
+  Operand Operand_,
+  /// Data type of A elements
+  typename Element_,
+  /// Layout of operand
+  typename Layout_,
+  /// Shape of one matrix product operation (concept: MatrixShape)
+  typename InstructionShape_,
+  /// Delta between *MMA operations (in units of *MMA operations, concept:
+  /// MatrixShape)
+  int OpDelta_,
+  /// Number of threads participating in one matrix operation
+  int Threads = 32,
+  /// Enable residual tile support
+  bool EnableResidual = false,
+  /// Number of partitions along K dimension
+  int PartitionsK_ = 1
+>
+class MmaTensorOpMultiplicandTileAccessIterator {
+ public:
+
+  /// Shape of tile to load (concept: MatrixShape)
+  using Shape = Shape_;
+
+  /// Operand tag
+  static Operand const kOperand = Operand_;
+
+  /// Basic check
+  static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
+    "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
+
+  /// Element type
+  using Element = Element_;
+
+  /// Layout of source tile
+  using Layout = Layout_;
+
+  /// Shape of one matrix product operation (concept: MatrixShape)
+  using InstructionShape = InstructionShape_;
+
+  /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
+  static int const kOpDelta = OpDelta_;
+
+  /// Number of participating threads
+  static int const kThreads = 32;
+
+  /// TensorRef type for loading elements from a tensor
+  using TensorRef = TensorRef<Element, Layout>;
+
+  /// Index type
+  using Index = typename TensorRef::Index;
+
+  /// Long Index type
+  using LongIndex = typename TensorRef::LongIndex;
+
+  /// Coordinate for an element in the tensor
+  using TensorCoord = typename TensorRef::TensorCoord;
+
+  /// Number of elements accessed per Shared Memory load
+  static int const kElementsPerAccess =
+      (sizeof_bits<Element>::value >= 32 ? 1 : 32 / sizeof_bits<Element>::value);
+
+  using InstructionCount = MatrixShape<
+    Shape::kRow / InstructionShape::kRow,
+    Shape::kColumn / InstructionShape::kColumn
+  >;
+
+  static int const kIterations = (kOperand == Operand::kA) ?
+    InstructionCount::kColumn : InstructionCount::kRow;
+
+public:
+
+  //
+  // Derived quantities
+  //
+
+  /// Fragment object holding a thread's part of a tile
+  using Fragment = Array<
+    Element,
+    (kOperand == Operand::kA) ?
+      (Shape::kRow * InstructionShape::kColumn / kThreads) :
+      (Shape::kColumn * InstructionShape::kRow / kThreads)
+  >;
+
+  /// Memory access type
+  using AccessType = AlignedArray<Element, kElementsPerAccess>;
+
+private:
+
+  /// Underlying tensor reference
+  TensorRef ref_;
+
+  /// Extent of tensor
+  MatrixCoord extent_;
+
+  /// Origin
+  MatrixCoord origin_;
+
+  /// Used to load residual tile
+  bool is_residual_;
+
+  /// Residual offset of each thread
+  TensorCoord residual_offset_;
+
+  /// Iterations in a tile
+  int iterations_;
+
+public:
+
+  /// Constructor from TensorRef with an explicit extent
+  CUTLASS_HOST_DEVICE
+  MmaTensorOpMultiplicandTileAccessIterator(
+    TensorRef const &ref,
+    TensorCoord extent,
+    int lane_id
+  ): ref_(ref), extent_(extent), is_residual_(false), iterations_(0) {
+
+    if (kOperand == Operand::kA) {
+      origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess);
+    }
+    else {
+      origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4);
+    }
+
+    ref_.add_coord_offset(origin_);
+
+    if (EnableResidual) {
+      // compute residual offset
+      if (kOperand == Operand::kA) {
+        typename TensorCoord::Index residual_size =
+          extent_.column() % Shape::kColumn;
+        if (residual_size) {
+          is_residual_ = true;
+          residual_offset_ = make_Coord(0, residual_size);
+        }
+      }
+      else {
+        typename TensorCoord::Index residual_size =
+          extent_.row() % Shape::kRow;
+        if (residual_size) {
+          is_residual_ = true;
+          residual_offset_ = make_Coord(residual_size, 0);
+        }
+      }
+    }
+  }
+
+  /// Constructor from TensorRef; the extent defaults to the tile shape
+  CUTLASS_HOST_DEVICE
+  MmaTensorOpMultiplicandTileAccessIterator(
+    TensorRef const &ref,
+    int lane_id
+  ): MmaTensorOpMultiplicandTileAccessIterator(ref,
+    {Shape::kRow, Shape::kColumn}, lane_id) {
+  }
+
+  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
+  CUTLASS_HOST_DEVICE
+  MmaTensorOpMultiplicandTileAccessIterator &add_tile_offset(TensorCoord const &tile_offset) {
+
+    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
+    origin_ += coord_offset;
+
+    ref_.add_coord_offset(coord_offset);
+
+    return *this;
+  }
+
+  /// Advances the iterator along the advance dimension
+  CUTLASS_DEVICE
+  void advance() {
+
+    if (EnableResidual && is_residual_) {
+      is_residual_ = false;
+
+      origin_ += residual_offset_;
+      ref_.add_coord_offset(residual_offset_);
+    }
+    else {
+      if (kOperand == Operand::kA) {
+        add_tile_offset({0, 1});
+      }
+      else {
+        add_tile_offset({1, 0});
+      }
+    }
+
+    iterations_ = 0;
+  }
+
+  /// Increments the iteration index within a tile
+  CUTLASS_HOST_DEVICE
+  MmaTensorOpMultiplicandTileAccessIterator &operator++() {
+
+    iterations_++;
+
+    if (iterations_ >= kIterations)
+      advance();
+
+    return *this;
+  }
+
+  /// Loads a fragment from memory at the location pointed to by the iterator.
+  CUTLASS_HOST_DEVICE
+  void load(Fragment &frag) const {
+
+    int const kWarpShapeDivisibleInner =
+      (kOperand == Operand::kA ? InstructionShape::kColumn : InstructionShape::kRow);
+
+    // Take advantage of Tensor Op's 8 x 4T access pattern
+    int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
+
+    AccessType *access_ptr = reinterpret_cast<AccessType *>(&frag);
+
+    if (kOperand == Operand::kA) {
+      int const kTilesPerInstruction = InstructionShape::kRow / 8;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) {
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
+
+          CUTLASS_PRAGMA_UNROLL
+          for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) {
+            int access_idx =
+              access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx);
+
+            MatrixCoord offset(
+              access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
+              inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kColumn);
+
+            MatrixCoord access_coord = origin_ + offset;
+
+//            if (access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) {
+
+              access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
+                ref_.data() + ref_.offset(offset));
+//            }
+//            else {
+//              AccessType zero;
+//              zero.clear();
+//              access_ptr[access_idx] = zero;
+//            }
+          }
+        }
+      }
+    }
+    else {
+      CUTLASS_PRAGMA_UNROLL
+      for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) {
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
+          int access_idx = inner_idx + kAccessesInner * inst_n_idx;
+
+          MatrixCoord offset(
+            inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kRow,
+            inst_n_idx * 8);
+
+          MatrixCoord access_coord = origin_ + offset;
+
+//          if (access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) {
+
+            access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
+              ref_.data() + ref_.offset(offset));
+//          }
+//          else {
+//            AccessType zero;
+//            zero.clear();
+//            access_ptr[access_idx] = zero;
+//          }
+        }
+      }
+    }
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace warp
+} // namespace gemm
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
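A hedged usage sketch for the new iterator (the shapes, element type, and function below are illustrative assumptions, not taken from this patch). Constructing with an explicit extent is what enables the residual path; the two-argument constructor keeps the old whole-tile behavior:

    #include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
    #include "cutlass/layout/matrix.h"
    #include "cutlass/matrix_shape.h"
    #include "cutlass/numeric_types.h"

    using WarpTileA = cutlass::MatrixShape<64, 32>;  // warp-level A tile (M x K)
    using InstShape = cutlass::MatrixShape<16, 8>;   // per-instruction shape (M x K)

    using IterA = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
        WarpTileA, cutlass::gemm::Operand::kA,
        cutlass::half_t, cutlass::layout::RowMajor, InstShape,
        /*OpDelta=*/1, /*Threads=*/32, /*EnableResidual=*/true>;

    __device__ void load_one_slice(IterA::TensorRef ref_A, int lane_id, int valid_k) {
      // Extent is {rows, columns}: only the first valid_k columns hold live data,
      // so the constructor derives a residual offset from the remainder.
      IterA iter(ref_A, {WarpTileA::kRow, valid_k}, lane_id);

      IterA::Fragment frag;
      iter.load(frag);  // one instruction-shaped slice for this lane
      ++iter;           // after kIterations slices, advance() applies the residual
                        // offset once before stepping by whole tiles
    }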
diff --git a/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h b/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h
index 8ab254d74..816767057 100644
--- a/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h
+++ b/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h
@@ -56,8 +56,20 @@ namespace threadblock {
 
 /// PredicatedVectorAccessIterator
 ///
-template <typename Shape, typename WarpShape, typename Element, typename Layout, int ElementsPerAccess>
+template <
+  /// Shape of the vector accessed by the entire threadblock
+  typename Shape,
+  /// Shape of the vector accessed by the warp
+  typename WarpShape,
+  /// Type of Element
+  typename Element,
+  /// Layout of the vector
+  typename Layout,
+  /// Number of elements for each access
+  int ElementsPerAccess,
+  /// Support residual tile
+  bool EnableResidualAccess = false
+>
 class PredicatedVectorAccessIterator;
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -65,8 +77,21 @@ class PredicatedVectorAccessIterator;
 
 /// Vector access iterator specialized for vectors, e.g. scale and bias
 /// Thread arrangements are for TensorOps
 ///
-template <typename Shape_, typename WarpShape_, typename Element_, int ElementsPerAccess>
-class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::PitchLinear, ElementsPerAccess> {
+template <
+  typename Shape_,
+  typename WarpShape_,
+  typename Element_,
+  int ElementsPerAccess,
+  bool EnableResidualAccess
+>
+class PredicatedVectorAccessIterator <
+  Shape_,
+  WarpShape_,
+  Element_,
+  layout::PitchLinear,
+  ElementsPerAccess,
+  EnableResidualAccess
+> {
 public:
   using Shape = Shape_;
@@ -116,6 +141,12 @@ class PredicatedVectorAccessIterator <
       TensorCoord const &threadblock_offset)
       : pointer_(reinterpret_cast<BytePointer>(
             const_cast<NonConstPointer>(pointer))),
-        extent_(extent) {
+        extent_(extent),
+        is_residual_(false) {
 
     int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous;
@@ -143,6 +175,15 @@ class PredicatedVectorAccessIterator <
 
     set_iteration_index(0);
+
+    if (EnableResidualAccess) {
+      // compute residual offset
+      typename TensorCoord::Index residual_size =
+          extent_.contiguous() % Shape::kContiguous;
+      if (residual_size) {
+        is_residual_ = true;
+        residual_offset_ = make_Coord(residual_size, 0);
+      }
+    }
   }
@@ -224,8 +265,21 @@ class PredicatedVectorAccessIterator;
 
 /// Vector access iterator specialized for vectors, e.g. scale and bias
 /// Thread arrangements are for TensorOps
 ///
-template <typename Shape_, typename WarpShape_, typename Element_, int ElementsPerAccess>
-class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::RowMajor, ElementsPerAccess> {
+template <
+  typename Shape_,
+  typename WarpShape_,
+  typename Element_,
+  int ElementsPerAccess,
+  bool EnableResidualAccess
+>
+class PredicatedVectorAccessIterator<
+  Shape_,
+  WarpShape_,
+  Element_,
+  layout::RowMajor,
+  ElementsPerAccess,
+  EnableResidualAccess
+> {
 public:
   using Shape = Shape_;
@@ -245,7 +305,8 @@ class PredicatedVectorAccessIterator<
   using UnderlyingIterator = PredicatedVectorAccessIterator<
       layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
       layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
       Element, layout::PitchLinear,
-      ElementsPerAccess>;
+      ElementsPerAccess,
+      EnableResidualAccess>;
 
   using AccessType = typename UnderlyingIterator::AccessType;
   static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
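Because EnableResidualAccess defaults to false, existing instantiations of PredicatedVectorAccessIterator keep their current behavior; the fused-GEMM scale/bias path opts in explicitly. A sketch of an opted-in instantiation (all shapes and types below are illustrative assumptions):

    #include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"

    using BiasIterator = cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<128, 128>,  // vector tile covered by the threadblock
        cutlass::MatrixShape<32, 64>,    // portion covered by one warp
        cutlass::half_t,
        cutlass::layout::RowMajor,
        /*ElementsPerAccess=*/8,
        /*EnableResidualAccess=*/true>;  // fold a non-divisible tail into the first
                                         // advance(), mirroring the warp iterator above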