Integrate Multi D GEMMs into Grouped GEMMs along with unit tests (#2923)

* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature

* feat: generalized grouped_gemm_kernel.hpp

* feat: generalized grouped_gemm_kernel.hpp even further by removing hardcoded 0

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: segfault fix by passing correct parameters for d tensors

* docs: add multi d info and trim down outdated content

* tests: add unit tests for grouped_gemm_multi_d and minor changes in grouped_gemm related test for compatibility

* style: clang format

* fix: incorrect validation method and Dtensor layout in test suite
This commit is contained in:
Aviral Goel
2025-09-26 12:59:58 -04:00
committed by GitHub
parent e40c0acef2
commit a44bea45b2
16 changed files with 1527 additions and 216 deletions

View File

@@ -9,7 +9,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
@@ -296,7 +295,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
@@ -325,7 +324,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
}
template <typename GemmConfig, typename T>