mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
GroupedGEMM + Gelu client example/instances/profiler (#614)
* Grouped gemm + Gelu instances. * Device Instance Factory for GroupedGemm+Gelu * Client example * Rangify fill helper functions. * Fix name clash. * Profiler for grouped_gemm+gelu * No need to use full namespace name. * Add check for MRaw divisible by vector load. * Ugly fix for big errors. * Add grouped_gemm+gelu to profiler CMakelists. * Store in argument additional info. * Information about Mraw, Nraw, Kraw values. * Use FastGelu instead of Gelu. * Change client ex to use FastGelu * Remove relaxed error precision. * Remove duplicate output elementwise-op --------- Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -381,6 +381,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
const index_t N = gemm_descs[i].N_;
|
||||
const index_t K = gemm_descs[i].K_;
|
||||
|
||||
a_mtx_mraw_kraw_.emplace_back(M, K);
|
||||
b_mtx_nraw_kraw_.emplace_back(N, K);
|
||||
|
||||
if(M == 0)
|
||||
{
|
||||
skipped_group_count_++;
|
||||
@@ -485,6 +488,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
CDEElementwiseOperation c_element_op_;
|
||||
|
||||
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
|
||||
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
|
||||
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
|
||||
|
||||
index_t grid_size_;
|
||||
};
|
||||
@@ -599,7 +604,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
bool supported = true;
|
||||
|
||||
// If we use padding we do not support vector loads for dimensions not divisible by vector
|
||||
// load size.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default)
|
||||
{
|
||||
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
|
||||
// thus we have to adapt it to the {M,K} or {N,K} layout.
|
||||
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
|
||||
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
|
||||
|
||||
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
|
||||
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
|
||||
}
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -661,7 +687,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave
|
||||
<< NXdlPerWave << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user