mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 14:59:01 +00:00
Support for Group GEMM in CUTLASS Profiler for Geforce and Spark (#3092)
Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
This commit is contained in:
@@ -46,7 +46,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|120f|121a")
|
||||
cutlass_example_add_executable(
|
||||
79a_blackwell_geforce_nvfp4_bf16_gemm
|
||||
79a_blackwell_geforce_nvfp4_bf16_gemm.cu
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|120f|121a")
|
||||
cutlass_example_add_executable(
|
||||
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm
|
||||
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|120f|121a")
|
||||
cutlass_example_add_executable(
|
||||
87a_blackwell_geforce_fp8_bf16_gemm_blockwise
|
||||
87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu
|
||||
|
||||
@@ -186,13 +186,13 @@ struct CollectiveBuilder<
|
||||
// Basic storage block for new Scaling Factor Layouts
|
||||
using mnBasicBlockShape = Shape<_32,_4>;
|
||||
using mnBasicBlockStride = Stride<_16,_4>;
|
||||
using kBasicBlockShape = Shape<Int<SFVectorSize>, Int<MMA_NSF>>;
|
||||
using kBasicBlockShape = Shape<Int<(int)SFVectorSize>, Int<MMA_NSF>>;
|
||||
using kBasicBlockStride = Stride<_0, _1>;
|
||||
|
||||
using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
|
||||
using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{}));
|
||||
using sSFA_strideM = sSF_strideMN;
|
||||
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
|
||||
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<(int)SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
|
||||
|
||||
using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
|
||||
using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{}));
|
||||
@@ -209,11 +209,6 @@ struct CollectiveBuilder<
|
||||
using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{}));
|
||||
using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{}));
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled<
|
||||
detail::sm120_smem_capacity_bytes, SmemAllocTypeA, SmemAllocTypeB, TileShape_MNK, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{});
|
||||
|
||||
static constexpr uint32_t SchedulerPipelineStageCount = 3;
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
|
||||
using StrideB = cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>;
|
||||
using InternalStrideA = cute::remove_pointer_t<StrideA>;
|
||||
@@ -232,6 +227,34 @@ struct CollectiveBuilder<
|
||||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>,
|
||||
"Invalid builder schedule tag for grouped GEMM");
|
||||
|
||||
|
||||
static constexpr uint32_t SchedulerPipelineStageCount = 3;
|
||||
|
||||
static constexpr int CLCResponseSize = sizeof(typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100<Shape<_1,_1,_1>,1>::CLCResponse{});
|
||||
|
||||
static constexpr auto SchedulerPipelineStorage = IsGroupedGemmKernel ? sizeof(cutlass::PipelineDetail::PipelineAsyncSharedStorage<8>)
|
||||
: sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, Shape<_1,_1,_1>>::SharedStorage);
|
||||
static constexpr auto CLCResponseStorage = IsGroupedGemmKernel ? 0 : (SchedulerPipelineStageCount *
|
||||
CLCResponseSize);
|
||||
static constexpr auto TensorMapStorage =
|
||||
IsGroupedGemmKernel ? sizeof(cute::TmaDescriptor) * 2 /* We have two tensormaps smem */ :
|
||||
0;
|
||||
|
||||
// TensorMapReady pipeline storage (specific to grouped/array kernels)
|
||||
static constexpr auto TensorMapReadyPipelineStorage =
|
||||
IsGroupedGemmKernel ? sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage) :
|
||||
0;
|
||||
|
||||
static constexpr int ReducedSmemCapacityBytes = detail::sm120_smem_capacity_bytes -
|
||||
SchedulerPipelineStorage -
|
||||
TensorMapStorage -
|
||||
TensorMapReadyPipelineStorage -
|
||||
CLCResponseStorage;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled<
|
||||
ReducedSmemCapacityBytes, SmemAllocTypeA, SmemAllocTypeB, TileShape_MNK, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{});
|
||||
|
||||
|
||||
using KernelSchedule = cute::conditional_t<IsGroupedGemmKernel,
|
||||
// PtrArray
|
||||
cute::conditional_t<IsCooperative,
|
||||
|
||||
@@ -760,6 +760,8 @@ select_instr() {
|
||||
(SfVectorSize == 32 && cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
|
||||
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelScheduleBlockScaledGemmSm100, BuilderScheduleTag>)
|
||||
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|
||||
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|
||||
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|
||||
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|
||||
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>)
|
||||
|| (SfVectorSize == 64 && cute::is_base_of_v<KernelScheduleBlockScaledSparseGemmSm100, BuilderScheduleTag>
|
||||
|
||||
@@ -11176,11 +11176,13 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
conv_kind = ConvKind.Fprop,
|
||||
log_indent_level = log_indent_level)
|
||||
|
||||
def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]]
|
||||
]
|
||||
@@ -11206,16 +11208,17 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
acc_types = [ DataType.f32 ]
|
||||
|
||||
def is_pingpong(kernel_schedule):
|
||||
if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120:
|
||||
if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 or \
|
||||
kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def tile_schedulers(sfdtype, kernel_schedule):
|
||||
# Pingpong kernel schedule doesn't support stream-K.
|
||||
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
|
||||
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
|
||||
if is_pingpong(kernel_schedule):
|
||||
if grouped or is_pingpong(kernel_schedule):
|
||||
return [TileSchedulerType.Default]
|
||||
elif sfdtype["type"] == DataType.void:
|
||||
return [TileSchedulerType.Default]
|
||||
@@ -11226,12 +11229,12 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
max_cc = 121
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
|
||||
math_instructions = []
|
||||
|
||||
kernel_schedules = [
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120,
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120
|
||||
to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120, grouped)
|
||||
]
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types):
|
||||
@@ -11299,16 +11302,18 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
|
||||
for data_type, kernel_schedule in product(data_types, kernel_schedules):
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
|
||||
gemm_kind = GemmKind.BlockScaledUniversal3x
|
||||
gemm_kind = gemm_kind
|
||||
)
|
||||
|
||||
def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM120 MMA with with F4 + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]]
|
||||
@@ -11344,11 +11349,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
def is_pingpong(kernel_schedule):
|
||||
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120 or \
|
||||
kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120:
|
||||
kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120 or \
|
||||
kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_nvf4(kernel_schedule):
|
||||
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120 or \
|
||||
kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120:
|
||||
@@ -11360,7 +11366,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
# Pingpong kernel schedule doesn't support stream-K.
|
||||
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
|
||||
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
|
||||
if is_pingpong(kernel_schedule):
|
||||
if grouped or is_pingpong(kernel_schedule):
|
||||
return [TileSchedulerType.Default]
|
||||
elif sfdtype["type"] == DataType.void:
|
||||
return [TileSchedulerType.Default]
|
||||
@@ -11374,12 +11380,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
math_instructions = []
|
||||
|
||||
kernel_schedules = [
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120
|
||||
]
|
||||
kernel_schedules = list(set([
|
||||
to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120, grouped)
|
||||
])) # ensure no duplicates
|
||||
|
||||
for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types):
|
||||
math_instructions.append(
|
||||
@@ -11394,12 +11400,16 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
for math_inst in math_instructions:
|
||||
for kernel_schedule in kernel_schedules:
|
||||
tile_descriptions = []
|
||||
is_grouped_schedule = grouped
|
||||
tile_sizes = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative
|
||||
for tile_size in tile_sizes:
|
||||
# nvf4 kernel only supports ue4m3 SF
|
||||
# mxf4 kernel only supports ue8m0 SF
|
||||
if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
|
||||
(math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
|
||||
# grouped schedules only support ue8m0 (MXF4); NVF4 (ue4m3) grouped requires
|
||||
# NVF4-specific PtrArray schedule tags not yet available
|
||||
if (is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0) or \
|
||||
(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
|
||||
(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
|
||||
tile_descriptions.append(
|
||||
TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
@@ -11482,10 +11492,10 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
for data_type in data_types:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
|
||||
gemm_kind = GemmKind.BlockScaledUniversal3x
|
||||
)
|
||||
gemm_kind = gemm_kind
|
||||
)
|
||||
|
||||
def GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
|
||||
@@ -12048,6 +12058,11 @@ def GenerateSM120(manifest, cuda_version):
|
||||
GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
#
|
||||
# Grouped Block Scaled Gemm
|
||||
#
|
||||
GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
#
|
||||
# Sparse Block Scaled Gemm
|
||||
#
|
||||
GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
|
||||
@@ -615,6 +615,9 @@ class KernelScheduleType(enum.Enum):
|
||||
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
|
||||
|
||||
PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120 = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120 = enum_auto()
|
||||
|
||||
KernelScheduleTag = {
|
||||
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
|
||||
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
|
||||
@@ -730,6 +733,8 @@ KernelScheduleTag = {
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120<3>',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongBlockScaledSm120<3>',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Sm120',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120',
|
||||
@@ -1040,6 +1045,13 @@ def to_grouped_schedule(schedule, grouped):
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch,
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch,
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch,
|
||||
# SM120
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
}
|
||||
|
||||
return group_schedule_map[schedule]
|
||||
|
||||
@@ -563,9 +563,9 @@ public:
|
||||
}
|
||||
|
||||
operator_args.mainloop.ptr_SFA =
|
||||
static_cast<const typename Operator::GemmKernel::ElementSF**>(arguments->SFA);
|
||||
static_cast<const typename CollectiveMainloop::ElementSF**>(arguments->SFA);
|
||||
operator_args.mainloop.ptr_SFB =
|
||||
static_cast<const typename Operator::GemmKernel::ElementSF**>(arguments->SFB);
|
||||
static_cast<const typename CollectiveMainloop::ElementSF**>(arguments->SFB);
|
||||
|
||||
operator_args.mainloop.layout_SFA =
|
||||
static_cast<typename CollectiveMainloop::InternalLayoutSFA*>(this->layout_SFA_device.data());
|
||||
|
||||
Reference in New Issue
Block a user