mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
This commit is contained in:
@@ -19,4 +19,15 @@ endforeach()
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental")
|
||||
target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
|
||||
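# Hacky workaround for bs_moe_stage2 with ROCm < 6.4: copy the prebuilt .co code objects from hsa/ into the build tree so the example can load them at run time.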
|
||||
add_custom_command(
|
||||
TARGET example_moe_gemm2_xdl_fp8_blockscale
|
||||
PRE_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hsa/
|
||||
${CMAKE_CURRENT_BINARY_DIR}/hsa/
|
||||
)
|
||||
|
||||
target_compile_definitions(example_moe_gemm2_xdl_fp8_blockscale PRIVATE MOE_STAGE2_ASM_DIR="${CMAKE_CURRENT_BINARY_DIR}/hsa/")
|
||||
BIN
example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co
Executable file
BIN
example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co
Executable file
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_32x128x256.co
Executable file
BIN
example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_32x128x256.co
Executable file
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v3_128x128x128.co
Executable file
BIN
example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v3_128x128x128.co
Executable file
Binary file not shown.
@@ -169,6 +169,9 @@
|
||||
// operations
|
||||
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
|
||||
// Use the prebuilt .co code objects (hand-written ASM) for moe_stage2_blockscale instead of the compiled CK kernels
|
||||
#define CK_USE_ASM_MOE_STAGE2_BLOCKSCALE 1
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
@@ -180,6 +181,94 @@ struct DeviceMoeGemmBlockScale
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
#if CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
|
||||
// Launches a prebuilt ASM kernel from a HIP code object.
//
//   hsa          file name of the code object, resolved relative to MOE_STAGE2_ASM_DIR
//   kernel_name  symbol of the kernel to look up inside that code object
//
// A copy of `arg` is handed to the kernel as one raw kernel-argument buffer. The grid is
// (gdx, gdy, 1) and the block is (256, 1, 1), as in the original launch.
// NOTE(review): the block size is hard-coded to 256 rather than BlockSize and gdz is not
// used; presumably this matches how the .co kernels were built -- confirm.
//
// When stream_config.time_kernel_ is set, the single launch is timed with HIP events and
// the elapsed time is stored in ave_time. Load/lookup/launch failures are reported with
// printf and leave ave_time untouched; event and synchronization failures go through
// hip_check_error.
//
// The module is unloaded before returning, so the launch is waited for on the non-timed
// path as well (the timed path already waits on the stop event).
const auto RunKernel = [&](const auto& hsa, const auto& kernel_name) {
#ifndef MOE_STAGE2_ASM_DIR
    // The build did not say where the code objects live. The #else below keeps the rest
    // of the body from referencing the undefined macro, which would not compile.
    (void)hsa;
    (void)kernel_name;
    printf("Failed to get moe_asm_dir.\n");
    return;
#else
    hipModule_t module = nullptr;
    auto status = hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str());
    if(status != hipSuccess)
    {
        printf("Failed to load module (%s): %s.\n", hsa.c_str(), hipGetErrorString(status));
        return;
    }

    // Releases the module on every exit path below, including exceptions thrown by
    // hip_check_error.
    struct ModuleGuard
    {
        hipModule_t handle;
        ~ModuleGuard() { (void)hipModuleUnload(handle); }
    };
    const ModuleGuard module_guard{module};

    hipFunction_t kernel_func = nullptr;
    status = hipModuleGetFunction(&kernel_func, module, kernel_name.c_str());
    if(status != hipSuccess)
    {
        printf(
            "Failed to get function (%s): %s.\n", kernel_name.c_str(), hipGetErrorString(status));
        return;
    }

    // Packed kernel-argument buffer. The markers 0x1 / 0x2 / 0x3 are the values of HIP's
    // HIP_LAUNCH_PARAM_BUFFER_POINTER / HIP_LAUNCH_PARAM_BUFFER_SIZE / HIP_LAUNCH_PARAM_END.
    auto kernarg      = arg;
    auto kernarg_size = sizeof(kernarg);
    void* config[]    = {reinterpret_cast<void*>(0x1),
                         reinterpret_cast<void*>(&kernarg),
                         reinterpret_cast<void*>(0x2),
                         &kernarg_size,
                         reinterpret_cast<void*>(0x3)};

    // Single launch site shared by the timed and non-timed paths.
    const auto launch = [&]() {
        return hipModuleLaunchKernel(
            kernel_func, gdx, gdy, 1, 256, 1, 1, 0, stream_config.stream_id_, nullptr, config);
    };

    if(stream_config.time_kernel_)
    {
        hipEvent_t start, stop;
        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        status = launch();
        if(status == hipSuccess)
        {
            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
            hip_check_error(hipEventSynchronize(stop));

            float total_time = 0;
            hip_check_error(hipEventElapsedTime(&total_time, start, stop));
            ave_time = total_time;
        }

        // Destroy the events whether or not the launch succeeded.
        hip_check_error(hipEventDestroy(start));
        hip_check_error(hipEventDestroy(stop));
    }
    else
    {
        status = launch();
        if(status == hipSuccess)
        {
            // The kernel must have finished before the module guard unloads its code.
            hip_check_error(hipStreamSynchronize(stream_config.stream_id_));
        }
    }

    if(status != hipSuccess)
    {
        printf("Failed to launch kernel: %s\n", hipGetErrorString(status));
    }
#endif
};
|
||||
#else
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
@@ -243,6 +332,7 @@ struct DeviceMoeGemmBlockScale
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
4 * (1 + GridwiseGemm::NWave);
|
||||
@@ -257,6 +347,66 @@ struct DeviceMoeGemmBlockScale
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
#if CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
|
||||
(void)minimum_occupancy;
|
||||
(void)MemoryDataOp;
|
||||
//get .co file name for ASM. select by version and shape.
|
||||
std::string hsa_name = "";
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(MPerBlock == 32)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v1_32x128x256");
|
||||
}
|
||||
else if constexpr(MPerBlock == 128){
|
||||
hsa_name = std::string("moe_bs_stage2_v1_128x128x128");
|
||||
}
|
||||
else{
|
||||
printf("Faild: only support 32x128x256 or 128x128x1288.\n");
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if constexpr(MPerBlock == 128)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
|
||||
}
|
||||
// else if constexpr(MPerBlock == 32){
|
||||
// hsa_name = std::string("moe_bs_stage2_v3_32x128x256");
|
||||
// }
|
||||
else{
|
||||
printf("Faild: v3 only support 128x128x1288.\n");
|
||||
}
|
||||
}
|
||||
else{
|
||||
printf("Faild: only support v1 or v3.\n");
|
||||
}
|
||||
//launch kernel
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
RunKernel(hsa_name+".co", hsa_name+"_odd_loop");
|
||||
}
|
||||
else
|
||||
{
|
||||
RunKernel(hsa_name+".co", hsa_name+"_even_loop");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
RunKernel(hsa_name+".co", hsa_name+"_odd_noloop");
|
||||
}
|
||||
else
|
||||
{
|
||||
RunKernel(hsa_name+".co", hsa_name+"_even_noloop");
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
@@ -341,31 +491,32 @@ struct DeviceMoeGemmBlockScale
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
// else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
// BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
// {
|
||||
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
// {
|
||||
// const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
// false,
|
||||
// MemoryDataOp,
|
||||
// minimum_occupancy,
|
||||
// IsInputGemm,
|
||||
// TailNumber::Odd>;
|
||||
// RunKernel(kernel);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
// false,
|
||||
// MemoryDataOp,
|
||||
// minimum_occupancy,
|
||||
// IsInputGemm,
|
||||
// TailNumber::Even>;
|
||||
// RunKernel(kernel);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return ave_time;
|
||||
|
||||
Reference in New Issue
Block a user