v4.2.1 update. (#2666)

This commit is contained in:
Junkai-Wu
2025-09-24 01:25:43 +08:00
committed by GitHub
parent 2b8dff1f90
commit 7a6d4ee099
12 changed files with 163 additions and 65 deletions

View File

@@ -7467,8 +7467,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]],
[[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]],
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):

View File

@@ -811,10 +811,14 @@ class EpilogueScheduleType(enum.Enum):
NoSmemWarpSpecialized2Sm = enum_auto()
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
@@ -834,10 +838,14 @@ EpilogueScheduleTag = {
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
@@ -858,10 +866,14 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
@@ -926,6 +938,8 @@ def to_grouped_schedule(schedule, grouped):
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm,
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm,
# SM103
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,