mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Fix develop: basic gemm
This commit is contained in:
@@ -69,9 +69,12 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -30,6 +30,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
|
||||
@@ -27,6 +27,8 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
@@ -45,6 +45,8 @@ struct GemmPipelineProblemBase
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -28,6 +28,7 @@ struct TileGemmTraits
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
|
||||
Reference in New Issue
Block a user