mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
DistGEMM bug fixes (#2713)
* Blackwell DistGEMM bug fixes 1. If using preferred cluster, there needs to be a branch so that the universal GEMM wrapper finds the correct base params. 2. Workspace sizes can change depending on problem shape in Blackwell, and DistGEMM was previously using the per-device shape to evaluate workspace size instead of the per-gemm shape. 3. Flattened size used to initialize host tensors can overflow (in Hopper example as well) 4. Preferred and fallback cluster args need to be set explicitly, otherwise if someone modifies the example to use preferred cluster, it will just fail. * Fix example runtimes * Set default fallback cluster shapes to the static ones
This commit is contained in:
@@ -132,7 +132,7 @@ using namespace cute;
|
||||
using TP = _8;
|
||||
static constexpr int TP_ = TP{};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
// Distributed GEMM tiling/sharding schedule
|
||||
@@ -252,7 +252,7 @@ HostTensorB tensor_B_arr[TP_];
|
||||
HostTensorD tensor_C_arr[TP_];
|
||||
HostTensorD tensor_D_arr[TP_];
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -345,8 +345,7 @@ struct Result {
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@@ -403,9 +402,9 @@ void initialize(const Options &options) {
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
|
||||
|
||||
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
|
||||
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
|
||||
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
|
||||
auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
|
||||
auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
|
||||
auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
@@ -650,7 +649,7 @@ int run(Options &options) {
|
||||
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
|
||||
size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
|
||||
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
|
||||
|
||||
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
|
||||
@@ -804,8 +803,7 @@ int run(Options &options) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
#endif //(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -859,7 +857,7 @@ int main(int argc, char const **args) {
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
|
||||
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
|
||||
run(options);
|
||||
#else
|
||||
std::cerr
|
||||
|
||||
@@ -132,7 +132,7 @@ using namespace cute;
|
||||
using TP = _8;
|
||||
static constexpr int TP_ = TP{};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
|
||||
// Distributed GEMM tiling/sharding schedule
|
||||
@@ -254,7 +254,7 @@ HostTensorB tensor_B_arr[TP_];
|
||||
HostTensorD tensor_C_arr[TP_];
|
||||
HostTensorD tensor_D_arr[TP_];
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -347,8 +347,7 @@ struct Result {
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@@ -405,9 +404,9 @@ void initialize(const Options &options) {
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
|
||||
|
||||
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
|
||||
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
|
||||
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
|
||||
auto a_coord = cutlass::make_Coord(size<2>(shape_A)*size<0>(shape_A), size<1>(shape_A));
|
||||
auto b_coord = cutlass::make_Coord(size<2>(shape_B)*size<0>(shape_B), size<1>(shape_B));
|
||||
auto c_coord = cutlass::make_Coord(size<2>(shape_C)*size<0>(shape_C), size<1>(shape_C));
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
@@ -475,6 +474,9 @@ GemmArguments gemm_args_from_options(const Options &options) {
|
||||
tensor_ref_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
// Preferred cluster can fail if these aren't set explicitly
|
||||
arguments.hw_info.cluster_shape = dim3(2,1,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(2,1,1);
|
||||
|
||||
return arguments;
|
||||
}
|
||||
@@ -548,6 +550,9 @@ DistGemmArguments dist_gemm_args_from_options(
|
||||
{}, // hw_info
|
||||
{} // scheduler
|
||||
};
|
||||
// Preferred cluster can fail if these aren't set explicitly
|
||||
arguments.hw_info.cluster_shape = dim3(2,1,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(2,1,1);
|
||||
|
||||
return arguments;
|
||||
}
|
||||
@@ -652,7 +657,7 @@ int run(Options &options) {
|
||||
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
|
||||
size_t workspace_size = DistGemm::get_workspace_size(arguments_, device_idx);
|
||||
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
|
||||
|
||||
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
|
||||
@@ -806,8 +811,7 @@ int run(Options &options) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
#endif // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -861,7 +865,7 @@ int main(int argc, char const **args) {
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
|
||||
#if ((__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
|
||||
run(options);
|
||||
#else
|
||||
std::cerr
|
||||
|
||||
@@ -253,16 +253,59 @@ public:
|
||||
return DistSchedule::get_tensor_D(tensor_D, tensor_buffer, device_idx, iteration);
|
||||
}
|
||||
|
||||
static
|
||||
auto make_dummy_base_args(Arguments const* args, int device_idx, int iteration, void ** buffer_space) {
|
||||
|
||||
// Set up GEMM arguments for the current stage/iteration
|
||||
auto tensor_a_iter = get_tensor_A_for_iter(args, buffer_space, device_idx, iteration);
|
||||
auto tensor_b_iter = get_tensor_B_for_iter(args, buffer_space, device_idx, iteration);
|
||||
auto tensor_c_iter = get_tensor_C_for_iter(args, buffer_space, device_idx, iteration);
|
||||
auto tensor_d_iter = get_tensor_D_for_iter(args, buffer_space, device_idx, iteration);
|
||||
|
||||
Arguments base_args = args[device_idx];
|
||||
base_args.problem_shape = DistSchedule::get_local_gemm_shape(args[device_idx].problem_shape);
|
||||
base_args.mainloop = {
|
||||
reinterpret_cast<const ElementA*>(tensor_a_iter.data()),
|
||||
tensor_a_iter.stride(),
|
||||
reinterpret_cast<const ElementB*>(tensor_b_iter.data()),
|
||||
tensor_b_iter.stride()
|
||||
};
|
||||
base_args.epilogue = {
|
||||
base_args.epilogue.thread,
|
||||
reinterpret_cast<const ElementC*>(tensor_c_iter.data()),
|
||||
tensor_c_iter.stride(),
|
||||
reinterpret_cast<ElementD*>(tensor_d_iter.data()),
|
||||
tensor_d_iter.stride()
|
||||
};
|
||||
|
||||
if constexpr (DistSchedule::RemoteC) {
|
||||
if (iteration > 0) {
|
||||
base_args.epilogue.thread.beta = 1.0;
|
||||
}
|
||||
else if (iteration == 0){
|
||||
base_args.epilogue.thread.beta = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
return base_args;
|
||||
}
|
||||
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
get_workspace_size(Arguments const* args, int device_idx) {
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
workspace_bytes = get_buffer_space_size(args);
|
||||
workspace_bytes = get_buffer_space_size(args[device_idx]);
|
||||
|
||||
void* dummy_buffer_space[TP_];
|
||||
|
||||
for (int iteration = 0; iteration < TP_; ++iteration) {
|
||||
// Workspace sizes can vary if arguments change, therefore we must
|
||||
// construct args for each iteration exactly as it will be run.
|
||||
auto args_base = make_dummy_base_args(args, device_idx, iteration, dummy_buffer_space);
|
||||
|
||||
// NOTE: assumes underlying kernels align up to alignment requirements on their own,
|
||||
// and that the alignment requirements of the individual kernels match.
|
||||
workspace_bytes += GemmKernel::get_workspace_size(args);
|
||||
workspace_bytes += GemmKernel::get_workspace_size(args_base);
|
||||
}
|
||||
|
||||
return workspace_bytes;
|
||||
|
||||
@@ -110,6 +110,13 @@ constexpr int stages_member(DispatchPolicy) {
|
||||
}
|
||||
}
|
||||
|
||||
template <class GemmKernel, class = void>
|
||||
struct IsDistGemmKernel : cute::false_type { };
|
||||
|
||||
template <typename GemmKernel>
|
||||
struct IsDistGemmKernel<GemmKernel, cute::void_t<typename GemmKernel::TP>>
|
||||
: cute::true_type { };
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class GemmKernel_>
|
||||
@@ -396,8 +403,13 @@ public:
|
||||
|| GemmKernel::ArchTag::kMinComputeCapability == 103
|
||||
) {
|
||||
if constexpr (!cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape>) {
|
||||
fallback_cluster = params.hw_info.cluster_shape_fallback;
|
||||
cluster = params.hw_info.cluster_shape;
|
||||
if constexpr (detail::IsDistGemmKernel<GemmKernel>::value) {
|
||||
fallback_cluster = params.base.hw_info.cluster_shape_fallback;
|
||||
cluster = params.base.hw_info.cluster_shape;
|
||||
} else {
|
||||
fallback_cluster = params.hw_info.cluster_shape_fallback;
|
||||
cluster = params.hw_info.cluster_shape;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user