Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-11 00:40:09 +00:00)
update pipeline v1: add atomic IGLP schedule
This commit is contained in:
@@ -155,7 +155,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
|||||||
using GemmPipelineProblem =
|
using GemmPipelineProblem =
|
||||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
||||||
|
|
||||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV0<GemmPipelineProblem>;
|
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||||
|
|
||||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||||
@@ -182,7 +182,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
|||||||
tail_number_v>;
|
tail_number_v>;
|
||||||
|
|
||||||
using CodegenFlatmmPipeline =
|
using CodegenFlatmmPipeline =
|
||||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV0<CodegenPipelineProblem>;
|
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||||
|
|
||||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ struct FlatmmConfig32
|
|||||||
static constexpr bool TransposeC = false;
|
static constexpr bool TransposeC = false;
|
||||||
static constexpr bool UseStructuredSparsity = false;
|
static constexpr bool UseStructuredSparsity = false;
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
static constexpr int kBlockPerCu = 1;
|
||||||
static constexpr int TileParitionerGroupNum = 8;
|
static constexpr int TileParitionerGroupNum = 8;
|
||||||
static constexpr int TileParitionerM01 = 4;
|
static constexpr int TileParitionerM01 = 4;
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
@@ -73,7 +73,7 @@ struct FlatmmConfig16
|
|||||||
static constexpr bool TransposeC = false;
|
static constexpr bool TransposeC = false;
|
||||||
static constexpr bool UseStructuredSparsity = false;
|
static constexpr bool UseStructuredSparsity = false;
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
static constexpr int kBlockPerCu = 1;
|
||||||
static constexpr int TileParitionerGroupNum = 8;
|
static constexpr int TileParitionerGroupNum = 8;
|
||||||
static constexpr int TileParitionerM01 = 4;
|
static constexpr int TileParitionerM01 = 4;
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
|
|||||||
@@ -55,8 +55,9 @@ int run_flatmm_example_with_layouts(int argc,
|
|||||||
// TODO: add different init types
|
// TODO: add different init types
|
||||||
if(init_method == 0)
|
if(init_method == 0)
|
||||||
{
|
{
|
||||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
// ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||||
memset(a_host.data(), 0, 4);
|
// ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||||
|
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
||||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
||||||
|
|||||||
File diff suppressed because it is too large (use "Load Diff" to display it).
Reference in New Issue
Block a user