// Review notes, placed ahead of the first `diff --git` header. The patch text below is unchanged.
//
// Purpose: when CK_USE_ASM_MOE_STAGE2_BLOCKSCALE is non-zero, DeviceMoeGemmBlockScale launches
// prebuilt code objects (.co) instead of the templated kernels.
//   * CMakeLists.txt: copies <source>/hsa/ to <binary>/hsa/ through add_custom_command(PRE_BUILD)
//     on example_moe_gemm2_xdl_fp8_blockscale and passes that binary-dir path to the target as
//     the MOE_STAGE2_ASM_DIR compile definition.
//   * hsa/: adds three binary code objects (v1 128x128x128, v1 32x128x256, v3 128x128x128).
//   * ck.hpp: defines CK_USE_ASM_MOE_STAGE2_BLOCKSCALE as 1.
//   * device_moe_gemm_blockscale.hpp: under that macro RunKernel(hsa, kernel_name) calls
//     hipModuleLoad, hipModuleGetFunction and hipModuleLaunchKernel with grid (gdx, gdy, 1) and
//     block (256, 1, 1); `arg` is copied and handed over through the `config` array. The file name
//     is chosen from BlkGemmPipelineVer and MPerBlock, and the kernel-name suffix
//     (_odd_loop/_even_loop/_odd_noloop/_even_noloop) from has_main_k_block_loop and
//     CalculateKBlockLoopTailNum(K_split). The v2/v3 branch of the non-ASM path is commented out.
//
// NOTE(review): this copy of the patch looks extraction-damaged. Hunks are collapsed onto single
//   lines and angle-bracket content seems stripped (`#include` with no header,
//   `reinterpret_cast(0x1)`, `kernel_moe_gemm_2lds;`). Recover the original before applying.
// NOTE(review): `#ifndef MOE_STAGE2_ASM_DIR` only guards the printf/return. The macro is still used
//   unconditionally in `std::string(MOE_STAGE2_ASM_DIR)`, so a build without the definition
//   presumably fails to compile instead of printing the message. TODO confirm and restructure.
// NOTE(review): hipModuleLoad runs on every RunKernel call and no hipModuleUnload is visible in
//   this patch. Looks like a module leak per invocation; consider caching or unloading.
// NOTE(review): in the time_kernel_ branch the early `return` after a failed launch skips both
//   hipEventDestroy calls. Timing covers a single launch, with no warm-up or repeat visible.
// NOTE(review): load/lookup/launch failures are only printf'd before `return`; nothing visible
//   here propagates an error to the caller, which still receives `ave_time`.
// NOTE(review): unsupported pipeline/MPerBlock combinations print a message but leave hsa_name
//   empty and fall through, so RunKernel is still called with ".co" and e.g. "_odd_loop".
// NOTE(review): block size is hard-coded to 256 and grid z to 1, while the non-ASM path launches
//   with dim3(gdx, gdy, gdz) and dim3(BlockSize). The ASM RunKernel also never reads
//   stream_config.flush_cache. Confirm these match what the code objects expect.
// NOTE(review): 0x1 / 0x2 / 0x3 in `config` are presumably HIP_LAUNCH_PARAM_BUFFER_POINTER,
//   HIP_LAUNCH_PARAM_BUFFER_SIZE and HIP_LAUNCH_PARAM_END. Verify and use the named macros.
// NOTE(review): the macro is set to 1 unconditionally in ck.hpp, and the ASM branch discards
//   MemoryDataOp (derived from IsInputGemm). Presumably every instantiation of this device op,
//   not only stage 2, is redirected to the stage-2 code objects. TODO confirm.
// NOTE(review): MOE_STAGE2_ASM_DIR bakes the build-tree path into the binary. Per the CMake docs,
//   PRE_BUILD is only a true pre-build step for Visual Studio generators. Verify the copy runs
//   when expected. The new CMakeLists.txt also still ends without a trailing newline.
// NOTE(review): runtime messages contain typos ("Faild", "Luach", "128x128x1288"). They are left
//   untouched here because they are runtime strings.
diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index f654550417..6a1a2582f6 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -19,4 +19,15 @@ endforeach() set(EXAMPLE_COMPILE_OPTIONS) list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental") -target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) \ No newline at end of file +target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + +#hacky fix for bs_moe_stage2 with rocm < 6.4 +add_custom_command( + TARGET example_moe_gemm2_xdl_fp8_blockscale + PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_directory + ${CMAKE_CURRENT_SOURCE_DIR}/hsa/ + ${CMAKE_CURRENT_BINARY_DIR}/hsa/ +) + +target_compile_definitions(example_moe_gemm2_xdl_fp8_blockscale PRIVATE MOE_STAGE2_ASM_DIR="${CMAKE_CURRENT_BINARY_DIR}/hsa/") \ No newline at end of file diff --git a/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co new file mode 100755 index 0000000000..1e6fea5a85 Binary files /dev/null and b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_32x128x256.co b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_32x128x256.co new file mode 100755 index 0000000000..821bfa77f4 Binary files /dev/null and b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_32x128x256.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v3_128x128x128.co b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v3_128x128x128.co new file mode 100755 index 
0000000000..128657224d Binary files /dev/null and b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v3_128x128x128.co differ diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 0c2dc799ab..fae4178941 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -169,6 +169,9 @@ // operations #define CK_USE_PK4_LAYOUT_SHUFFLE 1 +// using .co compiled shader for moe_stage2_blockscale +#define CK_USE_ASM_MOE_STAGE2_BLOCKSCALE 1 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index f707f1600b..a8466a311b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -5,6 +5,7 @@ #include #include +#include #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -180,6 +181,94 @@ struct DeviceMoeGemmBlockScale const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + #if CK_USE_ASM_MOE_STAGE2_BLOCKSCALE + const auto RunKernel = [&](const auto& hsa, const auto& kernel_name) { + // printf("Loading hip kernel\n"); + #ifndef MOE_STAGE2_ASM_DIR + printf("Failed to get moe_asm_dir.\n"); + return; + #endif + hipModule_t module; + hipFunction_t kernel_func; + auto status = hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str()); + if(status != hipSuccess) + { + printf("Failed to load module (%s): %s.\n", hsa.c_str(), hipGetErrorString(status)); + return; + } + status = hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()); + if(hipSuccess != status) + { + printf("Failed to get function (%s): %s.\n", kernel_name.c_str(), hipGetErrorString(status)); + return; + } + auto arg_size = sizeof(arg); + auto arg_ptr = arg; + // // 
RunKernel(impl_ptr); + void* config[] = {reinterpret_cast(0x1), + reinterpret_cast(&arg_ptr), + reinterpret_cast(0x2), + &arg_size, + reinterpret_cast(0x3)}; + + if(stream_config.time_kernel_) + { + // time kernel + hipEvent_t start, stop; + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + status = hipModuleLaunchKernel(kernel_func, + gdx, + gdy, + 1, + 256, + 1, + 1, + 0, + stream_config.stream_id_, + nullptr, + reinterpret_cast(&config)); + if(hipSuccess != status) + { + printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status)); + return; + } + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + ave_time = total_time; + } + else{ + status = hipModuleLaunchKernel(kernel_func, + gdx, + gdy, + 1, + 256, + 1, + 1, + 0, + stream_config.stream_id_, + nullptr, + reinterpret_cast(&config)); + if(hipSuccess != status) + { + printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status)); + return; + } + } + + }; + #else const auto RunKernel = [&](const auto& kernel) { if(stream_config.flush_cache) { @@ -243,6 +332,7 @@ struct DeviceMoeGemmBlockScale stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); } }; + #endif constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / 4 * (1 + GridwiseGemm::NWave); @@ -257,6 +347,66 @@ struct DeviceMoeGemmBlockScale constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + + #if CK_USE_ASM_MOE_STAGE2_BLOCKSCALE + (void)minimum_occupancy; + (void)MemoryDataOp; + //get .co file name for ASM. select by version and shape. 
+ std::string hsa_name = ""; + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(MPerBlock == 32) + { + hsa_name = std::string("moe_bs_stage2_v1_32x128x256"); + } + else if constexpr(MPerBlock == 128){ + hsa_name = std::string("moe_bs_stage2_v1_128x128x128"); + } + else{ + printf("Faild: only support 32x128x256 or 128x128x1288.\n"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if constexpr(MPerBlock == 128) + { + hsa_name = std::string("moe_bs_stage2_v3_128x128x128"); + } + // else if constexpr(MPerBlock == 32){ + // hsa_name = std::string("moe_bs_stage2_v3_32x128x256"); + // } + else{ + printf("Faild: v3 only support 128x128x1288.\n"); + } + } + else{ + printf("Faild: only support v1 or v3.\n"); + } + //launch kernel + if(has_main_k_block_loop) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + RunKernel(hsa_name+".co", hsa_name+"_odd_loop"); + } + else + { + RunKernel(hsa_name+".co", hsa_name+"_even_loop"); + } + } + else + { + // Tail number always 1 + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + RunKernel(hsa_name+".co", hsa_name+"_odd_noloop"); + } + else + { + RunKernel(hsa_name+".co", hsa_name+"_even_noloop"); + } + } + #else if(has_main_k_block_loop) { // Tail number always full @@ -341,31 +491,32 @@ struct DeviceMoeGemmBlockScale RunKernel(kernel); } } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_moe_gemm_2lds; - RunKernel(kernel); - } - else - { - const auto kernel = kernel_moe_gemm_2lds; - RunKernel(kernel); - } - } + // else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + // BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + // { + // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Odd) + // { + // const auto kernel = kernel_moe_gemm_2lds; + // RunKernel(kernel); + // } + // else + // { + // const auto kernel = kernel_moe_gemm_2lds; + // RunKernel(kernel); + // } + // } } +#endif #endif return ave_time;