DistGEMM bug fixes (#2713)

* Blackwell DistGEMM bug fixes

1. If using preferred cluster, there needs to be a branch so that
   the universal GEMM wrapper finds the correct base params.
2. Workspace sizes can change depending on problem shape in Blackwell,
   and DistGEMM was previously using the per-device shape to evaluate
   workspace size instead of the per-gemm shape.
3. The flattened size used to initialize host tensors can overflow
   (this affects the Hopper example as well)
4. Preferred and fallback cluster args need to be set explicitly;
   otherwise, if someone modifies the example to use the preferred
   cluster, it will simply fail.

* Fix example runtimes

* Set default fallback cluster shapes to the static ones
This commit is contained in:
Ali Hassani
2025-11-06 13:31:24 -05:00
committed by GitHub
parent 020c700e97
commit d1ef0e87f2
4 changed files with 84 additions and 27 deletions

View File

@@ -132,7 +132,7 @@ using namespace cute;
using TP = _8;
static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
// Distributed GEMM tiling/sharding schedule
@@ -252,7 +252,7 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -345,8 +345,7 @@ struct Result {
};
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@@ -403,9 +402,9 @@ void initialize(const Options &options) {
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
@@ -650,7 +649,7 @@ int run(Options &options) {
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
@@ -804,8 +803,7 @@ int run(Options &options) {
return 0;
}
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
#endif //(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -859,7 +857,7 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
run(options);
#else
std::cerr

View File

@@ -132,7 +132,7 @@ using namespace cute;
using TP = _8;
static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
// Distributed GEMM tiling/sharding schedule
@@ -254,7 +254,7 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -347,8 +347,7 @@ struct Result {
};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@@ -405,9 +404,9 @@ void initialize(const Options &options) {
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
@@ -475,6 +474,9 @@ GemmArguments gemm_args_from_options(const Options &options) {
tensor_ref_D.device_data(), stride_D
}
};
// Preferred cluster can fail if these aren't set explicitly
arguments.hw_info.cluster_shape = dim3(2,1,1);
arguments.hw_info.cluster_shape_fallback = dim3(2,1,1);
return arguments;
}
@@ -548,6 +550,9 @@ DistGemmArguments dist_gemm_args_from_options(
{}, // hw_info
{} // scheduler
};
// Preferred cluster can fail if these aren't set explicitly
arguments.hw_info.cluster_shape = dim3(2,1,1);
arguments.hw_info.cluster_shape_fallback = dim3(2,1,1);
return arguments;
}
@@ -652,7 +657,7 @@ int run(Options &options) {
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
@@ -806,8 +811,7 @@ int run(Options &options) {
return 0;
}
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
#endif // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -861,7 +865,7 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
run(options);
#else
std::cerr

View File

@@ -253,16 +253,59 @@ public:
return DistSchedule::get_tensor_D(tensor_D, tensor_buffer, device_idx, iteration);
}
// Builds the per-iteration base GEMM arguments exactly as they will be used at
// run time, so callers (e.g. workspace-size queries) see the same problem
// shape and tensor strides each stage of the distributed GEMM will run with.
// NOTE(review): `buffer_space` may contain dummy (uninitialized) pointers —
// the pointed-to data is never dereferenced here; only shapes/strides feed
// into the constructed arguments. TODO confirm against get_tensor_*_for_iter.
static
auto make_dummy_base_args(Arguments const* args, int device_idx, int iteration, void ** buffer_space) {
// Set up GEMM arguments for the current stage/iteration
auto tensor_a_iter = get_tensor_A_for_iter(args, buffer_space, device_idx, iteration);
auto tensor_b_iter = get_tensor_B_for_iter(args, buffer_space, device_idx, iteration);
auto tensor_c_iter = get_tensor_C_for_iter(args, buffer_space, device_idx, iteration);
auto tensor_d_iter = get_tensor_D_for_iter(args, buffer_space, device_idx, iteration);
// Start from this device's user-provided arguments, then override the
// problem shape with the per-GEMM (local) shape — using the per-device
// shape here was the workspace-sizing bug this commit fixes.
Arguments base_args = args[device_idx];
base_args.problem_shape = DistSchedule::get_local_gemm_shape(args[device_idx].problem_shape);
base_args.mainloop = {
reinterpret_cast<const ElementA*>(tensor_a_iter.data()),
tensor_a_iter.stride(),
reinterpret_cast<const ElementB*>(tensor_b_iter.data()),
tensor_b_iter.stride()
};
// Keep the user's epilogue thread params; swap in the per-iteration C/D views.
base_args.epilogue = {
base_args.epilogue.thread,
reinterpret_cast<const ElementC*>(tensor_c_iter.data()),
tensor_c_iter.stride(),
reinterpret_cast<ElementD*>(tensor_d_iter.data()),
tensor_d_iter.stride()
};
// When C lives remotely, the first iteration overwrites the output
// (beta = 0) and later iterations accumulate into it (beta = 1).
if constexpr (DistSchedule::RemoteC) {
if (iteration > 0) {
base_args.epilogue.thread.beta = 1.0;
}
else if (iteration == 0){
base_args.epilogue.thread.beta = 0.0;
}
}
return base_args;
}
static size_t
get_workspace_size(Arguments const& args) {
get_workspace_size(Arguments const* args, int device_idx) {
size_t workspace_bytes = 0;
workspace_bytes = get_buffer_space_size(args);
workspace_bytes = get_buffer_space_size(args[device_idx]);
void* dummy_buffer_space[TP_];
for (int iteration = 0; iteration < TP_; ++iteration) {
// Workspace sizes can vary if arguments change, therefore we must
// construct args for each iteration exactly as it will be run.
auto args_base = make_dummy_base_args(args, device_idx, iteration, dummy_buffer_space);
// NOTE: assumes underlying kernels align up to alignment requirements on their own,
// and that the alignment requirements of the individual kernels match.
workspace_bytes += GemmKernel::get_workspace_size(args);
workspace_bytes += GemmKernel::get_workspace_size(args_base);
}
return workspace_bytes;

View File

@@ -110,6 +110,13 @@ constexpr int stages_member(DispatchPolicy) {
}
}
// Detection trait: true iff `GemmKernel` exposes a nested `TP` type.
// Presumably `TP` is the tensor-parallel degree that only Distributed GEMM
// kernels declare — TODO confirm against the DistGEMM kernel definition.
// Primary template: not a DistGEMM kernel.
template <class GemmKernel, class = void>
struct IsDistGemmKernel : cute::false_type { };
// Specialization chosen (via void_t SFINAE) when GemmKernel::TP is well-formed.
template <typename GemmKernel>
struct IsDistGemmKernel<GemmKernel, cute::void_t<typename GemmKernel::TP>>
: cute::true_type { };
} // namespace detail
template <class GemmKernel_>
@@ -396,8 +403,13 @@ public:
|| GemmKernel::ArchTag::kMinComputeCapability == 103
) {
if constexpr (!cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape>) {
fallback_cluster = params.hw_info.cluster_shape_fallback;
cluster = params.hw_info.cluster_shape;
if constexpr (detail::IsDistGemmKernel<GemmKernel>::value) {
fallback_cluster = params.base.hw_info.cluster_shape_fallback;
cluster = params.base.hw_info.cluster_shape;
} else {
fallback_cluster = params.hw_info.cluster_shape_fallback;
cluster = params.hw_info.cluster_shape;
}
}
}