No-asm version for merging MoE blockscale fp8 into mainline

This commit is contained in:
OscarXu
2025-05-29 03:38:56 -05:00
parent 52d68c9529
commit 6be76c53b6
20 changed files with 1 additions and 475 deletions

View File

@@ -52,15 +52,4 @@ target_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE $
target_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
target_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
# Hacky fix for bs_moe_stage2 with ROCm < 6.4.
# Defines MOE_STAGE2_ASM_DIR on both MoE blockscale examples as the
# "hsa/<gpu>/" directory under the current source directory, where <gpu> is the
# first entry of GPU_TARGETS that also appears in gpu_list. If no entry of
# GPU_TARGETS matches, the macro is left undefined.
# NOTE(review): list(APPEND) extends gpu_list, so anything the variable already
# holds in this scope is kept as well - confirm it is empty at this point.
list(APPEND gpu_list gfx942 gfx950)
# "target" is a found-flag: it stays 0 until the first matching GPU is handled.
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
# The same macro name (MOE_STAGE2_...) is used for the gemm2 and gemm1 examples.
target_compile_definitions(example_moe_gemm2_xdl_fp8_blockscale PRIVATE MOE_STAGE2_ASM_DIR="${CMAKE_CURRENT_SOURCE_DIR}/hsa/${gpu}/")
target_compile_definitions(example_moe_gemm1_xdl_fp8_blockscale PRIVATE MOE_STAGE2_ASM_DIR="${CMAKE_CURRENT_SOURCE_DIR}/hsa/${gpu}/")
# Later matching entries are ignored once the flag is set.
set(target 1)
endif()
endforeach()
target_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})

View File

@@ -169,9 +169,6 @@
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// use precompiled .co code objects (loaded at run time) for the MoE blockscale stage1/stage2 kernels; 0 disables this path
#define CK_USE_ASM_MOE_BLOCKSCALE 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1

View File

@@ -167,260 +167,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return 0;
#endif
}
template <typename Args>
float launch_and_time_kernel_from_module(const StreamConfig& stream_config,
std::string hsa_dir,
std::string kernel_name,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args args)
{
hipModule_t module;
hipFunction_t kernel_func;
// printf("hsa_dir: %s, func: %s,\n", hsa_dir.c_str(), kernel_name.c_str());
hip_check_error(hipModuleLoad(&module, hsa_dir.c_str()));
hip_check_error(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));
auto arg_size = sizeof(args);
auto arg_ptr = args;
void* config[] = {reinterpret_cast<void*>(0x1),
reinterpret_cast<void*>(&arg_ptr),
reinterpret_cast<void*>(0x2),
&arg_size,
reinterpret_cast<void*>(0x3)};
#if CK_TIME_KERNEL
if(stream_config.time_kernel_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up %d times\n", stream_config.cold_niters_);
}
// warm up
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
}
const int nrepeat = stream_config.nrepeat_;
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
printf("Start running %d times...\n", nrepeat);
}
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < nrepeat; ++i)
{
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float total_time = 0;
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat;
}
else
{
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
return 0;
}
#else
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
return 0;
#endif
}
template <typename Args, typename PreProcessFunc>
float launch_and_time_kernel_from_module_with_preprocess(const StreamConfig& stream_config,
PreProcessFunc preprocess,
std::string hsa_dir,
std::string kernel_name,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args args)
{
hipModule_t module;
hipFunction_t kernel_func;
hip_check_error(hipModuleLoad(&module, hsa_dir.c_str()));
hip_check_error(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));
auto arg_size = sizeof(args);
auto arg_ptr = args;
void* config[] = {reinterpret_cast<void*>(0x1),
reinterpret_cast<void*>(&arg_ptr),
reinterpret_cast<void*>(0x2),
&arg_size,
reinterpret_cast<void*>(0x3)};
#if CK_TIME_KERNEL
if(stream_config.time_kernel_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up %d times\n", stream_config.cold_niters_);
}
// warm up
preprocess();
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
}
const int nrepeat = stream_config.nrepeat_;
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
printf("Start running %d times...\n", nrepeat);
}
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < nrepeat; ++i)
{
preprocess();
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float total_time = 0;
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat;
}
else
{
preprocess();
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
return 0;
}
#else
hip_check_error(hipModuleLaunchKernel(kernel_func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
lds_byte,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config)));
return 0;
#endif
}
#endif

View File

@@ -201,77 +201,6 @@ struct DeviceMoeGemmBlockScale
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
#if CK_USE_ASM_MOE_BLOCKSCALE
// Launches entry point `kernel_name` of the precompiled code object `hsa`
// (a file name relative to MOE_STAGE2_ASM_DIR) and stores the result of the
// launch helper in ave_time.
//
// Fix: the "macro missing" fallback is now a real #ifndef/#else split.
// Previously the code after the guard still named MOE_STAGE2_ASM_DIR, so a
// build without the macro failed to compile instead of printing the message.
const auto RunKernel = [&](const auto& hsa, const auto& kernel_name) {
#ifndef MOE_STAGE2_ASM_DIR
    // No directory was provided by the build: report it and launch nothing
    // (ave_time is left untouched).
    (void)hsa;
    (void)kernel_name;
    printf("Failed to get moe_asm_dir.\n");
    return;
#else
    const std::string code_object_path = std::string(MOE_STAGE2_ASM_DIR) + hsa;

    if(stream_config.flush_cache)
    {
        std::array<std::size_t, NumDTensor> DsSize;

        Argument arg_ = arg;

        // Byte sizes of the A, B and D buffers, derived from their grid descriptors.
        const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
            arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
        const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
            arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

        auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
        auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);

        const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
            arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
            DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
        });

        ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
            arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
        rotating_mem.Print();

        auto run_flush_cache = [&]() {
            // flush icache
            ck::utility::flush_icache();
            // rotating mem
            rotating_mem.Next();
            // clear c mem
            // NOTE(review): the hipMemsetAsync status is only converted to a
            // string and then dropped, so a failure here goes unnoticed.
            if(arg_.KBatch > 1)
                hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                 0,
                                                 arg_.M * arg_.N * sizeof(CDataType),
                                                 stream_config.stream_id_));
        };

        // NOTE(review): the launch receives `arg`, not the local copy `arg_`
        // that rotating_mem was built on, and the launch helper copies its
        // argument once before launching; whatever rotating_mem.Next() changes
        // in `arg_` therefore does not reach the kernel - confirm intent.
        ave_time = launch_and_time_kernel_from_module_with_preprocess(stream_config,
                                                                      run_flush_cache,
                                                                      code_object_path,
                                                                      kernel_name,
                                                                      dim3(gdx, gdy, gdz),
                                                                      dim3(BlockSize),
                                                                      0,
                                                                      arg);
    }
    else
    {
        ave_time = launch_and_time_kernel_from_module(stream_config,
                                                      code_object_path,
                                                      kernel_name,
                                                      dim3(gdx, gdy, gdz),
                                                      dim3(BlockSize),
                                                      0,
                                                      arg);
    }
#endif
};
#else
const auto RunKernel = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
@@ -335,7 +264,6 @@ struct DeviceMoeGemmBlockScale
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
#endif
// constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
// BlockSize /
@@ -352,137 +280,6 @@ struct DeviceMoeGemmBlockScale
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
#if CK_USE_ASM_MOE_BLOCKSCALE
(void)minimum_occupancy;
(void)MemoryDataOp;
// do_weight stage check
if(MulRoutedWeight == IsInputGemm)
{
throw std::runtime_error("MOE_BS_ASM Faild: Only gemm2 can do weight.\n");
}
// get .co file name for ASM. select by version and shape.
std::string hsa_name = "";
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(IsInputGemm)
{
if constexpr(MPerBlock == 32)
{
hsa_name = std::string("moe_bs_stage1_v1_32x128x128");
}
else if constexpr(MPerBlock == 64)
{
hsa_name = std::string("moe_bs_stage1_v1_64x128x128");
}
else
{
throw std::runtime_error(
"MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
}
if constexpr(ActivationOP == Activation::silu_and_mul)
{
hsa_name += "_silu";
}
else if constexpr(ActivationOP == Activation::gelu_and_mul)
{
hsa_name += "_gelu";
}
else
{
throw std::runtime_error(
"MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
}
}
else
{
if constexpr(MPerBlock == 32)
{
hsa_name = std::string("moe_bs_stage2_v1_32x128x128");
}
else if constexpr(MPerBlock == 128)
{
hsa_name = std::string("moe_bs_stage2_v1_128x128x128");
}
else
{
throw std::runtime_error(
"MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(IsInputGemm)
{
if constexpr(MPerBlock == 64)
{
hsa_name = std::string("moe_bs_stage1_v3_64x128x128");
}
else
{
throw std::runtime_error(
"MOE_BS_ASM Faild: v3 only support 64x128x1288.\n");
}
if constexpr(ActivationOP == Activation::silu_and_mul)
{
hsa_name += "_silu";
}
else if constexpr(ActivationOP == Activation::gelu_and_mul)
{
hsa_name += "_gelu";
}
else
{
throw std::runtime_error(
"MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
}
}
else
{
if constexpr(MPerBlock == 128)
{
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
}
else if constexpr(MPerBlock == 64)
{
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
}
else
{
throw std::runtime_error(
"MOE_BS_ASM Faild: v3 only support 128x128x128 or 64x128x128.\n");
}
}
}
else
{
throw std::runtime_error("MOE_BS_ASM Faild: only support v1 or v3.\n");
}
// launch kernel
if(has_main_k_block_loop)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
RunKernel(hsa_name + ".co", std::string("odd_loop"));
}
else
{
RunKernel(hsa_name + ".co", std::string("even_loop"));
}
}
else
{
// Tail number always 1
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
RunKernel(hsa_name + ".co", std::string("odd_noloop"));
}
else
{
RunKernel(hsa_name + ".co", std::string("even_noloop"));
}
}
#else
if(has_main_k_block_loop)
{
// Tail number always full
@@ -584,7 +381,6 @@ struct DeviceMoeGemmBlockScale
}
}
}
#endif
#endif
return ave_time;