diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 3d3a54020c..e1dce331d1 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -2,9 +2,14 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) +set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 5068707160..c51a3937fe 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -249,9 +249,9 @@ struct GemmConfigPreshuffle_1 : public GemmConfigBase static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V3; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1; static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; + static constexpr bool DoubleSmemBuffer = false; }; template @@ -271,7 +271,7 @@ struct GemmConfigPreshuffle_2 : public GemmConfigBase static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V3; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; static constexpr bool DoubleSmemBuffer = true; }; @@ -291,7 +291,7 @@ struct GemmConfigPreshuffle_3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 2; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 247b21a3fc..372bf0615f 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -36,10 +36,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: GemmConfig::PermuteA, GemmConfig::PermuteB>; + // using TilePartitioner = + // ck_tile::GemmSpatiallyLocalTilePartitioner; + using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + ck_tile::GemmTile1DPartitioner; using Traits = ck_tile::TileGemmTraits& args, const ck_tile: const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + std::cout << "k_grain: " << k_grain << " K_split: " << K_split << std::endl; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); @@ -290,5 +294,5 @@ int main(int argc, char* argv[]) // Return a non-zero code to indicate failure return EXIT_FAILURE; } - return EXIT_SUCCESS; + //return EXIT_SUCCESS; } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index aafc6c0a85..ca5ed41dd4 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -41,10 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); + // r.x = __builtin_amdgcn_readfirstlane(r.x); + // r.y = __builtin_amdgcn_readfirstlane(r.y); + // r.z = __builtin_amdgcn_readfirstlane(r.z); + // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp index f4659c44fe..e60b654079 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp @@ -113,6 +113,7 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1 merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); }); }); }); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 53c21b49f5..cbc509e424 100755 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -246,6 +246,11 @@ struct GemmKernel { return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize2ndBuffer() + { + return GemmPipeline::GetSmemSize(); + } struct SplitKBatchOffset { @@ -950,7 +955,7 @@ struct GemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[(GemmPipeline::Preshuffle) ? GetSmemSize2ndBuffer() : GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 28e8bee908..b03c58981f 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -112,7 +112,7 @@ struct GemmTile1DPartitioner * @param N GEMM's N dimension. * @return dim3 Structure holding grid's X,Y and Z dimensions. */ - CK_TILE_HOST static auto + CK_TILE_HOST_DEVICE static auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t { const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c19d42ce25..1326c462af 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -118,6 +118,10 @@ struct GemmPipelineProblemBase } static constexpr index_t VectorSizeA = []() { + + // std::cout << "FixedVectorSize: " << FixedVectorSize << std::endl; + // std::cout << "kPadK: " << kPadK << std::endl; + // std::cout << "kPadM: " << kPadM << std::endl; if constexpr(FixedVectorSize) { return VectorSizeA_; @@ -133,6 +137,9 @@ struct GemmPipelineProblemBase }(); static constexpr index_t VectorSizeB = []() { + // std::cout << "FixedVectorSize: " << FixedVectorSize << std::endl; + // std::cout << "kPadK: " << kPadK << std::endl; + // std::cout << "kPadN: " << kPadN << std::endl; if constexpr(FixedVectorSize) { return VectorSizeB_; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index a4a6f9a9cb..4d7c072a66 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -14,36 +14,37 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() - { - using ADataType = remove_cvref_t; - // using ALayout = remove_cvref_t; + // template + // CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() //looks like this function is not getting used + // { + // using ADataType = remove_cvref_t; + // // using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; + // constexpr index_t BlockSize = Problem::kBlockSize; - // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - constexpr index_t M1 = BlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - // constexpr index_t M0 = MPerBlock / (M2 * M1); - // static_assert(M0 * M1 * M2 == MPerBlock, - // "Incorrect M0, M2, M1 configuration! " - // "M0, M1, M2 must cover whole MPerBlock!"); + // constexpr index_t K1 = 16 / sizeof(ADataType); + // constexpr index_t K0 = KPerBlock / K1; + // constexpr index_t M2 = get_warp_size() / K0; + // constexpr index_t M1 = BlockSize / get_warp_size(); + // static_assert(K1 == 1, "M2 is zero, which will lead to a division by zero error."); + // static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + // static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + // // constexpr index_t M0 = MPerBlock / (M2 * M1); + // // static_assert(M0 * M1 * M2 == MPerBlock, + // // "Incorrect M0, M2, M1 configuration! " + // // "M0, M1, M2 must cover whole MPerBlock!"); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2>, - sequence<1>>{}); - } + // return make_static_tile_distribution( + // tile_distribution_encoding, + // tuple, sequence>, + // tuple, sequence<1, 2>>, + // tuple, sequence<1, 0>>, + // sequence<2>, + // sequence<1>>{}); + // } // 3d + padding template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 692e4b4218..ebbb760670 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -17,8 +17,12 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t GlobalBufferNum = 1; static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) { + + std::cout << "BlockHasHotloop: " << num_loop << std::endl; return num_loop > PrefetchStages; } @@ -33,10 +37,12 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 { if(tail_number == TailNumber::Odd) { + std::cout << "TailHandler: Odd" << std::endl; run_func(bool_constant{}, integral_constant{}); } else if(tail_number == TailNumber::Even) { + std::cout << "TailHandler: Even" << std::endl; run_func(bool_constant{}, integral_constant{}); } } @@ -74,8 +80,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; - static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } - static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeA() { + return PipelinePolicy::template GetVectorSizeA(); + //return Problem::VectorSizeA; + } + static constexpr index_t GetVectorSizeB() { + return PipelinePolicy::template GetVectorSizeB(); + //return Problem::VectorSizeB; + } static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; @@ -127,7 +139,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // clang-format on } - static constexpr bool DoubleSmemBuffer = true; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; @@ -226,6 +238,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 #if defined(__gfx950__) if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) { + //printf("Inside gfx950, with 16x16 128x256x256 \n"); static_for<0, 2, 1>{}([&](auto j) { ignore = j; static_for<0, 3, 1>{}([&](auto i) { @@ -273,6 +286,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } else { + //printf("Inside gfx950, with 16x16 otherwise \n"); static_for<0, 2, 1>{}([&](auto j) { ignore = j; static_for<0, 3, 1>{}([&](auto i) { @@ -311,8 +325,9 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // MFMA → MFMA → MFMA → MFMA → DS Read // For other device engine we need more agressive MFMA with DS writes interleaved #else - if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) + if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) //TODO :: 128x256x128 { + //printf("Inside gfx942, with 16x16 128x256x256 \n"); static_for<0, 2, 1>{}([&](auto j) { ignore = j; // Uses loops to amortize scheduling overhead @@ -388,6 +403,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256) { + //printf("Inside gfx942, with 16x16 16x64x256 \n"); static_for<0, 1, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -416,6 +432,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128) { + //printf("Inside gfx942, with 16x16 128x128x128 \n"); // prioritize MFMA to avoid LDS write conflicts static_for<0, 2, 1>{}([&](auto j) { ignore = j; @@ -478,6 +495,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } else { + //printf("Inside gfx942, with 16x16 otherwise \n"); static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read @@ -505,6 +523,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } else { + //printf("Inside gfx950 or gfx942, with other then 16x16 any block sizes \n"); if constexpr((A_LDS_Read_Inst_Num / 2 > A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) {