mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Fix for Unsupported Input Shapes/Sizes in Stream-K GEMM - BF16/FP16 (#1866)
This commit is contained in:
committed by
GitHub
parent
c287418dcc
commit
92b79ead0a
@@ -230,6 +230,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// Pad both M and K to be multiples of the block sizes
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(K, KPad - K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_pass_through_transform(MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
@@ -296,6 +313,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
@@ -312,6 +330,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// Pad both N and K to be multiples of the block sizes
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(N, NPad - N),
|
||||
make_right_pad_transform(K, KPad - K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_pass_through_transform(NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
@@ -378,6 +413,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
@@ -412,6 +448,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// Pad both M and N to be multiples of the block sizes
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
@@ -449,6 +492,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
struct Problem
|
||||
@@ -953,7 +997,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
@@ -970,7 +1015,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
@@ -1036,6 +1082,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1051,6 +1098,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1065,6 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1082,6 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1098,17 +1151,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
Normal file → Executable file
0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
Normal file → Executable file
Reference in New Issue
Block a user