Fix develop: basic gemm

This commit is contained in:
Mateusz Ozga
2025-06-16 15:12:24 +00:00
parent d996bc78be
commit 2d23e434ff
6 changed files with 12 additions and 0 deletions

View File

@@ -69,9 +69,12 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -2,6 +2,8 @@
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {

View File

@@ -30,6 +30,8 @@ struct GemmPipelineAGmemBGmemCRegV1
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;

View File

@@ -27,6 +27,8 @@ struct GemmPipelineAGmemBGmemCRegV2
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;

View File

@@ -45,6 +45,8 @@ struct GemmPipelineProblemBase
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -28,6 +28,7 @@ struct TileGemmTraits
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
};
template <bool kPadM_,