mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
updates
This commit is contained in:
@@ -84,8 +84,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return 32; }
|
||||
static constexpr index_t GetVectorSizeB() { return 32; /* fixed for fp4 shuffle layout*/ }
|
||||
static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle layout*/
|
||||
static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
@@ -126,11 +126,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
|
||||
static constexpr index_t MIterScalePerWarp = MIterPerWarp / MXdlPack;
|
||||
static constexpr index_t NIterScalePerWarp = NIterPerWarp / NXdlPack;
|
||||
static constexpr index_t KIterScalePerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
static constexpr int MXFP4PackedSize = 2;
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
|
||||
|
||||
@@ -138,29 +133,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
|
||||
// static constexpr int ContinuousKPerThread = Problem::ContinuousKPerThread;
|
||||
// static constexpr int ContinuousScaleNPerThread = Problem::ContinuousScaleNPerThread;
|
||||
// static constexpr int ContinuousScaleKPerThread = Problem::ContinuousScaleKPerThread;
|
||||
|
||||
// static constexpr int ScaleKFlatPerWarp =
|
||||
// ContinuousScaleNPerThread * ContinuousScaleKPerThread * get_warp_size();
|
||||
|
||||
// static constexpr int XDLK_PerThread =
|
||||
// WarpTile::at(I2) / (get_warp_size() / WarpTile::at(I1)); // 8
|
||||
|
||||
// static constexpr int XDL_PerWeightK = 4; // 4
|
||||
// static constexpr int XDL_PerScaleK = XDL_PerWeightK * ContinuousScaleKPerThread; // 4
|
||||
// static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
|
||||
// static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
|
||||
// static_assert(KIterPerWarp % XDL_PerScaleK == 0);
|
||||
// static_assert(NIterPerWarp % XDL_PerScaleN == 0);
|
||||
|
||||
// static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
|
||||
// static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
|
||||
// static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
|
||||
|
||||
// static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
@@ -170,19 +142,21 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t mfma_per_wg = 1;
|
||||
#endif
|
||||
static constexpr index_t dsread_per_wg =
|
||||
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
|
||||
WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize / Problem::VectorLoadSize;
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize) %
|
||||
Problem::VectorLoadSize ==
|
||||
0);
|
||||
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
|
||||
static constexpr index_t ScaleBload_num =
|
||||
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
|
||||
WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
|
||||
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 / WaveSize;
|
||||
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
@@ -277,6 +251,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// The key to optimizing this pipeline is balancing the workload over time
|
||||
@@ -407,10 +382,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
{
|
||||
load_perM = load_perM + 1;
|
||||
}
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user