cutlass 3.9 update (#2255)

* cutlass 3.9 update

* rebase

* fix out-of-shared-memory failures for blockwise kernels on Blackwell

* fix documentation formatting

* fix issue 2253

* disable host reference checking by default

* fix SM120 shared memory capacity

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Author: Yujia Zhai
Date: 2025-04-24 12:42:40 -07:00 (committed by GitHub)
Commit: 331a1f5b3f
Parent: 8e345c5c5b
143 changed files with 18089 additions and 5935 deletions
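
Every hunk below applies the same gate: blockwise GEMM kinds get the new BlockwiseTmaWarpSpecializedCooperative kernel schedule paired with the TMA cooperative epilogue, and are excluded from the auto, pingpong, CpAsync, and FP8-fast-accum paths. A minimal runnable sketch of that pattern; the Enum bodies and gemm-kind names here are illustrative stand-ins, not the CUTLASS definitions:

from enum import Enum, auto

class KernelScheduleType(Enum):
  TmaWarpSpecializedCooperative = auto()
  TmaWarpSpecializedCooperativeFP8FastAccum = auto()
  BlockwiseTmaWarpSpecializedCooperative = auto()

class EpilogueScheduleType(Enum):
  TmaWarpSpecializedCooperative = auto()

# Hypothetical gemm-kind tags; the real code dispatches on GemmKind enums.
BLOCKWISE_KINDS = {"blockwise", "grouped_blockwise"}

def is_blockwise(gemm_kind):
  return gemm_kind in BLOCKWISE_KINDS

def cooperative_schedules(gemm_kind):
  # Blockwise kernels get only the blockwise cooperative mainloop;
  # all other kinds keep the cooperative and FP8 fast-accum variants.
  if is_blockwise(gemm_kind):
    return [[KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
             EpilogueScheduleType.TmaWarpSpecializedCooperative]]
  return [[KernelScheduleType.TmaWarpSpecializedCooperative,
           EpilogueScheduleType.TmaWarpSpecializedCooperative],
          [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
           EpilogueScheduleType.TmaWarpSpecializedCooperative]]

print(cooperative_schedules("blockwise"))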

@@ -511,16 +511,23 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
       return [], []
     if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
       schedules = []
-      schedules.append(
-        [
-          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
-          to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
-        ])
-      schedules.append(
-        [
-          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
-          to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
-        ])
+      if is_blockwise(gemm_kind):
+        schedules.append(
+          [
+            to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
+            to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+          ])
+      else:
+        schedules.append(
+          [
+            to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+            to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+          ])
+        schedules.append(
+          [
+            to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+            to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+          ])
       return schedules, []
     return [], []
@@ -547,26 +554,34 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
       epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
       if a_type_size > b_type_size:
         epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
-      schedules.append([
-        KernelScheduleType.TmaWarpSpecialized,
-        epilogue_schedule
-      ])
-      schedules.append([
-        KernelScheduleType.TmaWarpSpecializedPingpong,
-        epilogue_schedule
-      ])
+      if not is_blockwise(gemm_kind):
+        schedules.append([
+          KernelScheduleType.TmaWarpSpecialized,
+          epilogue_schedule
+        ])
+        schedules.append([
+          KernelScheduleType.TmaWarpSpecializedPingpong,
+          epilogue_schedule
+        ])
       if cta_m >= 128:
         if a_type_size > b_type_size:
           epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
         else:
           epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
-        schedules.append([
-          KernelScheduleType.TmaWarpSpecializedCooperative,
-          epilogue_schedule
-        ])
+        if is_blockwise(gemm_kind):
+          schedules.append([
+            KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
+            epilogue_schedule
+          ])
+        else:
+          schedules.append([
+            KernelScheduleType.TmaWarpSpecializedCooperative,
+            epilogue_schedule
+          ])
       return schedules, []
-    if not is_aligned:
+    if not is_aligned and not is_blockwise(gemm_kind):
       schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
                     default_epilogue]]
       stream_k_schedules = []
@@ -585,7 +600,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     schedules = []
     # Pruning: emit Void-C and Grouped kernels with persistent kernels only
-    if (level >= 1 or not is_void_c) and not grouped:
+    if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind):
       # Pruning: don't stamp out fp8 kernels with auto schedule
       if not is_fp8:
         schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
@@ -596,7 +611,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     if can_do_tma_epilogue:
       assert not requires_transposed_epilogue
       # Inconsistency: fp8 pingpong only gets stamped out with fast accum
-      if not is_fp8 or level >= 1:
+      if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind):
         schedules.append([
           to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
           to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
@@ -618,14 +633,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
         schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
       if can_do_cooperative:
-        schedules.append([
-          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
-          to_grouped_schedule(default_epilogue, grouped)
-        ])
-        stream_k_schedules.append([
-          KernelScheduleType.TmaWarpSpecializedCooperative,
-          default_epilogue
-        ])
+        if is_blockwise(gemm_kind):
+          schedules.append([
+            to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
+            to_grouped_schedule(default_epilogue, grouped)
+          ])
+          stream_k_schedules.append([
+            KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
+            default_epilogue
+          ])
+        else:
+          schedules.append([
+            to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+            to_grouped_schedule(default_epilogue, grouped)
+          ])
+          stream_k_schedules.append([
+            KernelScheduleType.TmaWarpSpecializedCooperative,
+            default_epilogue
+          ])
       if can_do_fp8_fast_accum:
         schedules.append([
           to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
@@ -640,14 +665,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     if can_do_tma_epilogue:
       assert not requires_transposed_epilogue
       if can_do_cooperative:
-        schedules.append([
-          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
-          to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
-        ])
-        stream_k_schedules.append([
-          KernelScheduleType.TmaWarpSpecializedCooperative,
-          EpilogueScheduleType.TmaWarpSpecializedCooperative
-        ])
+        if is_blockwise(gemm_kind):
+          schedules.append([
+            to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
+            to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+          ])
+          stream_k_schedules.append([
+            KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
+            EpilogueScheduleType.TmaWarpSpecializedCooperative
+          ])
+        else:
+          schedules.append([
+            to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+            to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+          ])
+          stream_k_schedules.append([
+            KernelScheduleType.TmaWarpSpecializedCooperative,
+            EpilogueScheduleType.TmaWarpSpecializedCooperative
+          ])
       if can_do_fp8_fast_accum:
         schedules.append([
           to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
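
For context, the to_grouped_schedule helper used throughout these hunks leaves a schedule unchanged for non-grouped kernels and substitutes its grouped variant otherwise. A hypothetical sketch of that contract; the mapping entries and names below are illustrative assumptions, not CUTLASS's actual table:

# Hypothetical sketch of the to_grouped_schedule contract: identity when
# grouped is False, otherwise a lookup of the grouped schedule variant.
# The mapping below is an illustrative assumption, not CUTLASS's table.
GROUPED_VARIANT = {
  "TmaWarpSpecializedCooperative": "GroupedTmaWarpSpecializedCooperative",
  "BlockwiseTmaWarpSpecializedCooperative": "GroupedBlockwiseTmaWarpSpecializedCooperative",
}

def to_grouped_schedule(schedule, grouped):
  if not grouped:
    return schedule
  return GROUPED_VARIANT[schedule]

# Non-grouped kernels keep their schedule; grouped kernels swap in the
# grouped variant.
assert to_grouped_schedule("TmaWarpSpecializedCooperative", grouped=False) == "TmaWarpSpecializedCooperative"
print(to_grouped_schedule("BlockwiseTmaWarpSpecializedCooperative", grouped=True))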