mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
No asm ver. for merging moe blocksale fp8 into mainline
This commit is contained in:
@@ -52,15 +52,4 @@ target_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE $
|
||||
target_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
|
||||
target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
target_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
|
||||
#hacky fix for bs_moe_stage2 with rocm < 6.4
|
||||
list(APPEND gpu_list gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
target_compile_definitions(example_moe_gemm2_xdl_fp8_blockscale PRIVATE MOE_STAGE2_ASM_DIR="${CMAKE_CURRENT_SOURCE_DIR}/hsa/${gpu}/")
|
||||
target_compile_definitions(example_moe_gemm1_xdl_fp8_blockscale PRIVATE MOE_STAGE2_ASM_DIR="${CMAKE_CURRENT_SOURCE_DIR}/hsa/${gpu}/")
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
target_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -169,9 +169,6 @@
|
||||
// operations
|
||||
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
|
||||
// using .co compiled shader for moe_stage2_blockscale
|
||||
#define CK_USE_ASM_MOE_BLOCKSCALE 0
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
|
||||
|
||||
@@ -167,260 +167,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Args>
|
||||
float launch_and_time_kernel_from_module(const StreamConfig& stream_config,
|
||||
std::string hsa_dir,
|
||||
std::string kernel_name,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args args)
|
||||
{
|
||||
hipModule_t module;
|
||||
hipFunction_t kernel_func;
|
||||
// printf("hsa_dir: %s, func: %s,\n", hsa_dir.c_str(), kernel_name.c_str());
|
||||
hip_check_error(hipModuleLoad(&module, hsa_dir.c_str()));
|
||||
hip_check_error(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));
|
||||
auto arg_size = sizeof(args);
|
||||
auto arg_ptr = args;
|
||||
void* config[] = {reinterpret_cast<void*>(0x1),
|
||||
reinterpret_cast<void*>(&arg_ptr),
|
||||
reinterpret_cast<void*>(0x2),
|
||||
&arg_size,
|
||||
reinterpret_cast<void*>(0x3)};
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up %d times\n", stream_config.cold_niters_);
|
||||
}
|
||||
// warm up
|
||||
for(int i = 0; i < stream_config.cold_niters_; ++i)
|
||||
{
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}
|
||||
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
|
||||
float total_time = 0;
|
||||
|
||||
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
|
||||
|
||||
hip_check_error(hipEventDestroy(start));
|
||||
hip_check_error(hipEventDestroy(stop));
|
||||
|
||||
return total_time / nrepeat;
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Args, typename PreProcessFunc>
|
||||
float launch_and_time_kernel_from_module_with_preprocess(const StreamConfig& stream_config,
|
||||
PreProcessFunc preprocess,
|
||||
std::string hsa_dir,
|
||||
std::string kernel_name,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args args)
|
||||
{
|
||||
hipModule_t module;
|
||||
hipFunction_t kernel_func;
|
||||
|
||||
hip_check_error(hipModuleLoad(&module, hsa_dir.c_str()));
|
||||
hip_check_error(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));
|
||||
auto arg_size = sizeof(args);
|
||||
auto arg_ptr = args;
|
||||
void* config[] = {reinterpret_cast<void*>(0x1),
|
||||
reinterpret_cast<void*>(&arg_ptr),
|
||||
reinterpret_cast<void*>(0x2),
|
||||
&arg_size,
|
||||
reinterpret_cast<void*>(0x3)};
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up %d times\n", stream_config.cold_niters_);
|
||||
}
|
||||
// warm up
|
||||
preprocess();
|
||||
for(int i = 0; i < stream_config.cold_niters_; ++i)
|
||||
{
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
preprocess();
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}
|
||||
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
|
||||
float total_time = 0;
|
||||
|
||||
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
|
||||
|
||||
hip_check_error(hipEventDestroy(start));
|
||||
hip_check_error(hipEventDestroy(stop));
|
||||
|
||||
return total_time / nrepeat;
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess();
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -201,77 +201,6 @@ struct DeviceMoeGemmBlockScale
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
#if CK_USE_ASM_MOE_BLOCKSCALE
|
||||
const auto RunKernel = [&](const auto& hsa, const auto& kernel_name) {
|
||||
// printf("Loading hip kernel\n");
|
||||
#ifndef MOE_STAGE2_ASM_DIR
|
||||
printf("Failed to get moe_asm_dir.\n");
|
||||
return;
|
||||
#endif
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
ave_time = launch_and_time_kernel_from_module_with_preprocess(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
std::string(MOE_STAGE2_ASM_DIR) + hsa,
|
||||
kernel_name,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_and_time_kernel_from_module(stream_config,
|
||||
std::string(MOE_STAGE2_ASM_DIR) + hsa,
|
||||
kernel_name,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg);
|
||||
}
|
||||
};
|
||||
#else
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
@@ -335,7 +264,6 @@ struct DeviceMoeGemmBlockScale
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
|
||||
// BlockSize /
|
||||
@@ -352,137 +280,6 @@ struct DeviceMoeGemmBlockScale
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
#if CK_USE_ASM_MOE_BLOCKSCALE
|
||||
(void)minimum_occupancy;
|
||||
(void)MemoryDataOp;
|
||||
// do_weight stage check
|
||||
if(MulRoutedWeight == IsInputGemm)
|
||||
{
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Only gemm2 can do weight.\n");
|
||||
}
|
||||
// get .co file name for ASM. select by version and shape.
|
||||
std::string hsa_name = "";
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(IsInputGemm)
|
||||
{
|
||||
if constexpr(MPerBlock == 32)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage1_v1_32x128x128");
|
||||
}
|
||||
else if constexpr(MPerBlock == 64)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage1_v1_64x128x128");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
|
||||
}
|
||||
if constexpr(ActivationOP == Activation::silu_and_mul)
|
||||
{
|
||||
hsa_name += "_silu";
|
||||
}
|
||||
else if constexpr(ActivationOP == Activation::gelu_and_mul)
|
||||
{
|
||||
hsa_name += "_gelu";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MPerBlock == 32)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v1_32x128x128");
|
||||
}
|
||||
else if constexpr(MPerBlock == 128)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v1_128x128x128");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(IsInputGemm)
|
||||
{
|
||||
if constexpr(MPerBlock == 64)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage1_v3_64x128x128");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"MOE_BS_ASM Faild: v3 only support 64x128x1288.\n");
|
||||
}
|
||||
if constexpr(ActivationOP == Activation::silu_and_mul)
|
||||
{
|
||||
hsa_name += "_silu";
|
||||
}
|
||||
else if constexpr(ActivationOP == Activation::gelu_and_mul)
|
||||
{
|
||||
hsa_name += "_gelu";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MPerBlock == 128)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
|
||||
}
|
||||
else if constexpr(MPerBlock == 64)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"MOE_BS_ASM Faild: v3 only support 128x128x128 or 64x128x128.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: only support v1 or v3.\n");
|
||||
}
|
||||
// launch kernel
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
RunKernel(hsa_name + ".co", std::string("odd_loop"));
|
||||
}
|
||||
else
|
||||
{
|
||||
RunKernel(hsa_name + ".co", std::string("even_loop"));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
RunKernel(hsa_name + ".co", std::string("odd_noloop"));
|
||||
}
|
||||
else
|
||||
{
|
||||
RunKernel(hsa_name + ".co", std::string("even_noloop"));
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
@@ -584,7 +381,6 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return ave_time;
|
||||
|
||||
Reference in New Issue
Block a user