Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE

This commit is contained in:
OscarXu
2025-04-30 22:05:54 +08:00
parent fe1648f0ce
commit efbb85be2a
6 changed files with 190 additions and 25 deletions

View File

@@ -19,4 +19,15 @@ endforeach()
set(EXAMPLE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental")
target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
#hacky fix for bs_moe_stage2 with rocm < 6.4
add_custom_command(
TARGET example_moe_gemm2_xdl_fp8_blockscale
PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_CURRENT_SOURCE_DIR}/hsa/
${CMAKE_CURRENT_BINARY_DIR}/hsa/
)
target_compile_definitions(example_moe_gemm2_xdl_fp8_blockscale PRIVATE MOE_STAGE2_ASM_DIR="${CMAKE_CURRENT_BINARY_DIR}/hsa/")

View File

@@ -169,6 +169,9 @@
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// using .co compiled shader for moe_stage2_blockscale
#define CK_USE_ASM_MOE_STAGE2_BLOCKSCALE 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1

View File

@@ -5,6 +5,7 @@
#include <iostream>
#include <sstream>
#include <hip/hip_runtime.h>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -180,6 +181,94 @@ struct DeviceMoeGemmBlockScale
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
#if CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
const auto RunKernel = [&](const auto& hsa, const auto& kernel_name) {
// printf("Loading hip kernel\n");
#ifndef MOE_STAGE2_ASM_DIR
printf("Failed to get moe_asm_dir.\n");
return;
#endif
hipModule_t module;
hipFunction_t kernel_func;
auto status = hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str());
if(status != hipSuccess)
{
printf("Failed to load module (%s): %s.\n", hsa.c_str(), hipGetErrorString(status));
return;
}
status = hipModuleGetFunction(&kernel_func, module, kernel_name.c_str());
if(hipSuccess != status)
{
printf("Failed to get function (%s): %s.\n", kernel_name.c_str(), hipGetErrorString(status));
return;
}
auto arg_size = sizeof(arg);
auto arg_ptr = arg;
// // RunKernel(impl_ptr);
void* config[] = {reinterpret_cast<void*>(0x1),
reinterpret_cast<void*>(&arg_ptr),
reinterpret_cast<void*>(0x2),
&arg_size,
reinterpret_cast<void*>(0x3)};
if(stream_config.time_kernel_)
{
// time kernel
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
status = hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
1,
256,
1,
1,
0,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config));
if(hipSuccess != status)
{
printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status));
return;
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float total_time = 0;
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
ave_time = total_time;
}
else{
status = hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
1,
256,
1,
1,
0,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config));
if(hipSuccess != status)
{
printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status));
return;
}
}
};
#else
const auto RunKernel = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
@@ -243,6 +332,7 @@ struct DeviceMoeGemmBlockScale
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
#endif
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
4 * (1 + GridwiseGemm::NWave);
@@ -257,6 +347,66 @@ struct DeviceMoeGemmBlockScale
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
#if CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
(void)minimum_occupancy;
(void)MemoryDataOp;
//get .co file name for ASM. select by version and shape.
std::string hsa_name = "";
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if constexpr(MPerBlock == 32)
{
hsa_name = std::string("moe_bs_stage2_v1_32x128x256");
}
else if constexpr(MPerBlock == 128){
hsa_name = std::string("moe_bs_stage2_v1_128x128x128");
}
else{
printf("Faild: only support 32x128x256 or 128x128x1288.\n");
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if constexpr(MPerBlock == 128)
{
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
}
// else if constexpr(MPerBlock == 32){
// hsa_name = std::string("moe_bs_stage2_v3_32x128x256");
// }
else{
printf("Faild: v3 only support 128x128x1288.\n");
}
}
else{
printf("Faild: only support v1 or v3.\n");
}
//launch kernel
if(has_main_k_block_loop)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
RunKernel(hsa_name+".co", hsa_name+"_odd_loop");
}
else
{
RunKernel(hsa_name+".co", hsa_name+"_even_loop");
}
}
else
{
// Tail number always 1
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
RunKernel(hsa_name+".co", hsa_name+"_odd_noloop");
}
else
{
RunKernel(hsa_name+".co", hsa_name+"_even_noloop");
}
}
#else
if(has_main_k_block_loop)
{
// Tail number always full
@@ -341,31 +491,32 @@ struct DeviceMoeGemmBlockScale
RunKernel(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Even>;
RunKernel(kernel);
}
}
// else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
// BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
// false,
// MemoryDataOp,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Odd>;
// RunKernel(kernel);
// }
// else
// {
// const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
// false,
// MemoryDataOp,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Even>;
// RunKernel(kernel);
// }
// }
}
#endif
#endif
return ave_time;