mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Integrate Multi D GEMMs into Grouped GEMMs along with unit tests (#2923)
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * feat: generalized grouped_gemm_kernel.hpp * feat: generalized grouped_gemm_kernel.hpp even further by removing hardcoded 0 * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * docs: add multi d info and trim down outdated content * tests: add unit tests for grouped_gemm_multi_d and minor changes in grouped_gemm related test for compatibility * style: clang format * fix: incorrect validation method and Dtensor layout in test suite
This commit is contained in:
@@ -9,7 +9,6 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
@@ -296,7 +295,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
|
||||
|
||||
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -325,7 +324,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
|
||||
Reference in New Issue
Block a user