mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:32:36 +00:00
Merge branch 'mmflat' of https://github.com/ROCm/composable_kernel into mmflat
This commit is contained in:
@@ -2,9 +2,14 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
|
||||
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
|
||||
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
|
||||
|
||||
@@ -249,9 +249,9 @@ struct GemmConfigPreshuffle_1 : public GemmConfigBase
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V3;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -271,7 +271,7 @@ struct GemmConfigPreshuffle_2 : public GemmConfigBase
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V3;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
@@ -291,7 +291,7 @@ struct GemmConfigPreshuffle_3 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
|
||||
@@ -36,10 +36,13 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
// using TilePartitioner =
|
||||
// ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
// GemmConfig::TileParitionerGroupNum,
|
||||
// GemmConfig::TileParitionerM01>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
ck_tile::GemmTile1DPartitioner<GemmShape>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -69,6 +72,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
std::cout << "k_grain: " << k_grain << " K_split: " << K_split << std::endl;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
@@ -290,5 +294,5 @@ int main(int argc, char* argv[])
|
||||
// Return a non-zero code to indicate failure
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
//return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -41,10 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -113,6 +113,7 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -246,6 +246,11 @@ struct GemmKernel
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize2ndBuffer()
|
||||
{
|
||||
return GemmPipeline::GetSmemSize();
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
@@ -950,7 +955,7 @@ struct GemmKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[(GemmPipeline::Preshuffle) ? GetSmemSize2ndBuffer() : GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
|
||||
@@ -112,7 +112,7 @@ struct GemmTile1DPartitioner
|
||||
* @param N GEMM's N dimension.
|
||||
* @return dim3 Structure holding grid's X,Y and Z dimensions.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
|
||||
{
|
||||
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
|
||||
@@ -118,6 +118,10 @@ struct GemmPipelineProblemBase
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
|
||||
// std::cout << "FixedVectorSize: " << FixedVectorSize << std::endl;
|
||||
// std::cout << "kPadK: " << kPadK << std::endl;
|
||||
// std::cout << "kPadM: " << kPadM << std::endl;
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeA_;
|
||||
@@ -133,6 +137,9 @@ struct GemmPipelineProblemBase
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
// std::cout << "FixedVectorSize: " << FixedVectorSize << std::endl;
|
||||
// std::cout << "kPadK: " << kPadK << std::endl;
|
||||
// std::cout << "kPadN: " << kPadN << std::endl;
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeB_;
|
||||
|
||||
@@ -14,36 +14,37 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() //looks like this function is not getting used
|
||||
// {
|
||||
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
// // using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
// constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
// constexpr index_t K0 = KPerBlock / K1;
|
||||
// constexpr index_t M2 = get_warp_size() / K0;
|
||||
// constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
// static_assert(K1 == 1, "M2 is zero, which will lead to a division by zero error.");
|
||||
// static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
// static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// // constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// // static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// // "Incorrect M0, M2, M1 configuration! "
|
||||
// // "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<sequence<1>,
|
||||
// tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
// tuple<sequence<1>, sequence<1, 2>>,
|
||||
// tuple<sequence<0>, sequence<1, 0>>,
|
||||
// sequence<2>,
|
||||
// sequence<1>>{});
|
||||
// }
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
|
||||
@@ -17,8 +17,12 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
|
||||
std::cout << "BlockHasHotloop: " << num_loop << std::endl;
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
@@ -33,10 +37,12 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
std::cout << "TailHandler: Odd" << std::endl;
|
||||
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
{
|
||||
std::cout << "TailHandler: Even" << std::endl;
|
||||
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
@@ -74,8 +80,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
|
||||
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
|
||||
static constexpr index_t GetVectorSizeA() {
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem>();
|
||||
//return Problem::VectorSizeA;
|
||||
}
|
||||
static constexpr index_t GetVectorSizeB() {
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem>();
|
||||
//return Problem::VectorSizeB;
|
||||
}
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
@@ -127,7 +139,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
@@ -226,6 +238,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
|
||||
{
|
||||
//printf("Inside gfx950, with 16x16 128x256x256 \n");
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
@@ -273,6 +286,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
else
|
||||
{
|
||||
//printf("Inside gfx950, with 16x16 otherwise \n");
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
@@ -311,8 +325,9 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
// MFMA → MFMA → MFMA → MFMA → DS Read
|
||||
// For other device engine we need more agressive MFMA with DS writes interleaved
|
||||
#else
|
||||
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
|
||||
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) //TODO :: 128x256x128
|
||||
{
|
||||
//printf("Inside gfx942, with 16x16 128x256x256 \n");
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
// Uses loops to amortize scheduling overhead
|
||||
@@ -388,6 +403,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256)
|
||||
{
|
||||
//printf("Inside gfx942, with 16x16 16x64x256 \n");
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
@@ -416,6 +432,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128)
|
||||
{
|
||||
//printf("Inside gfx942, with 16x16 128x128x128 \n");
|
||||
// prioritize MFMA to avoid LDS write conflicts
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
@@ -478,6 +495,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
else
|
||||
{
|
||||
//printf("Inside gfx942, with 16x16 otherwise \n");
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
@@ -505,6 +523,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
else
|
||||
{
|
||||
//printf("Inside gfx950 or gfx942, with other then 16x16 any block sizes \n");
|
||||
if constexpr((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user