mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
cutlass 3.9 update (#2255)
* cutlass 3.9 update * rebase * fixes out of shared memory for blockwise Blackwell * doc format * fix issue 2253 * disable host ref by default * fix sm120 smem capacity --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@@ -511,16 +511,23 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
return [], []
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
|
||||
schedules = []
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
else:
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
return schedules, []
|
||||
return [], []
|
||||
|
||||
@@ -547,26 +554,34 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
|
||||
if a_type_size > b_type_size:
|
||||
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecialized,
|
||||
epilogue_schedule
|
||||
])
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
epilogue_schedule
|
||||
])
|
||||
|
||||
if not is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecialized,
|
||||
epilogue_schedule
|
||||
])
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
epilogue_schedule
|
||||
])
|
||||
if cta_m >= 128:
|
||||
if a_type_size > b_type_size:
|
||||
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
|
||||
else:
|
||||
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
else:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
return schedules, []
|
||||
|
||||
if not is_aligned:
|
||||
if not is_aligned and not is_blockwise(gemm_kind):
|
||||
schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
|
||||
default_epilogue]]
|
||||
stream_k_schedules = []
|
||||
@@ -585,7 +600,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
|
||||
schedules = []
|
||||
# Pruning: emit Void-C and Grouped kernels with persistent kernels only
|
||||
if (level >= 1 or not is_void_c) and not grouped:
|
||||
if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind):
|
||||
# Pruning: don't stamp out fp8 kernels with auto schedule
|
||||
if not is_fp8:
|
||||
schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
|
||||
@@ -596,7 +611,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
if can_do_tma_epilogue:
|
||||
assert not requires_transposed_epilogue
|
||||
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
|
||||
if not is_fp8 or level >= 1:
|
||||
if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
|
||||
@@ -618,14 +633,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
|
||||
|
||||
if can_do_cooperative:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
default_epilogue
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
||||
default_epilogue
|
||||
])
|
||||
else:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
default_epilogue
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
@@ -640,14 +665,24 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
if can_do_tma_epilogue:
|
||||
assert not requires_transposed_epilogue
|
||||
if can_do_cooperative:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
if is_blockwise(gemm_kind):
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
else:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
|
||||
Reference in New Issue
Block a user