Fix for Unsupported Input Shapes/Sizes in Stream-K GEMM - BF16/FP16 (#1866)

This commit is contained in:
Muhammed Emin Ozturk
2025-02-18 08:46:47 -08:00
committed by GitHub
parent c287418dcc
commit 92b79ead0a
5 changed files with 85 additions and 16 deletions

View File

@@ -230,6 +230,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
// Pad both M and K to be multiples of the block sizes
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
@@ -296,6 +313,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return a_grid_desc_ak0_m_ak1;
}
#endif
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
@@ -312,6 +330,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
// Pad both N and K to be multiples of the block sizes
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
@@ -378,6 +413,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return b_grid_desc_bk0_n_bk1;
}
#endif
}
template <typename ABlockDesc_AK0_M_AK1>
@@ -412,6 +448,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
// Pad both M and N to be multiples of the block sizes
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
@@ -449,6 +492,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
struct Problem
@@ -953,7 +997,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
{
@@ -970,7 +1015,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
@@ -1036,6 +1082,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1051,6 +1098,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
@@ -1065,6 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1082,6 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
@@ -1098,17 +1151,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}

View File