diff --git a/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co index 1e6fea5a85..8d740f3e94 100755 Binary files a/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co and b/example/65_gemm_multiply_multiply/hsa/moe_bs_stage2_v1_128x128x128.co differ diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 6bbd08c020..7d34749a1a 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -118,7 +118,7 @@ static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; -#if 1 +#if 0 static constexpr ck::index_t MPerBlock = 32; static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t MXDLPerWave = 1; @@ -161,11 +161,11 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso MPerBlock, 128, 128, 16, 16, 32, 32, - 4, 1, + 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; #endif // clang-format on @@ -180,11 +180,11 @@ int main(int argc, char* argv[]) // experts = 8 // per expert: - constexpr ck::index_t valid_tile_num = 52; - constexpr ck::index_t sorted_tile_num = valid_tile_num + 3; + constexpr ck::index_t valid_tile_num = 13; //13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock + constexpr ck::index_t sorted_tile_num = valid_tile_num;// + 3; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; -#if 0 +#if 1 // GEMM shape ck::index_t N = 6144; ck::index_t K = 4096; @@ -249,14 +249,22 @@ int main(int argc, char* argv[]) // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} //int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; int eids[sorted_tile_num]{}; + int e_select = 0; for(int i = 0; i < sorted_tile_num; i++) { if (i < valid_tile_num){ - eids[i] = std::rand() % experts; + eids[i] = e_select; + //std::rand() % experts; } else{ eids[i] = 3; } + if (i > ((e_select + 1) * (sorted_tile_num / experts))){ + e_select++; + if (e_select >= experts){ + e_select = experts - 1; + } + } } // int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -319,9 +327,9 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; @@ -445,7 +453,7 @@ int main(int argc, char* argv[]) float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << device_op.GetTypeString() << std::endl; + << " GB/s.\n" << device_op.GetTypeString() << std::endl; } if(do_verification) @@ -540,10 +548,14 @@ int main(int argc, char* argv[]) #endif // e_t_n_device_result.savetxt("out.txt"); // e_t_n_host_result.savetxt("ref.txt"); - return ck::utils::check_err( + auto status = ck::utils::check_err( e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 0 : 1; + if (status == 0){ + printf("Validation Pass.\n"); + } + return status; } return 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 143f1f85d2..e5b733d1bf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -190,18 +190,8 @@ struct DeviceMoeGemmBlockScale #endif hipModule_t module; hipFunction_t kernel_func; - auto status = hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str()); - if(status != hipSuccess) - { - printf("Failed to load module (%s): %s.\n", hsa.c_str(), hipGetErrorString(status)); - return; - } - status = hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()); - if(hipSuccess != status) - { - printf("Failed to get function (%s): %s.\n", kernel_name.c_str(), hipGetErrorString(status)); - return; - } + hip_check_error(hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str())); + hip_check_error(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); auto arg_size = sizeof(arg); auto arg_ptr = arg; // // RunKernel(impl_ptr); @@ -221,7 +211,7 @@ struct DeviceMoeGemmBlockScale hip_check_error(hipDeviceSynchronize()); hip_check_error(hipEventRecord(start, stream_config.stream_id_)); - status = hipModuleLaunchKernel(kernel_func, + hip_check_error(hipModuleLaunchKernel(kernel_func, gdx, gdy, 1, @@ -231,12 +221,7 @@ struct DeviceMoeGemmBlockScale 0, stream_config.stream_id_, nullptr, - reinterpret_cast(&config)); - if(hipSuccess != status) - { - printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status)); - return; - } + reinterpret_cast(&config))); hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventSynchronize(stop)); @@ -249,7 +234,7 @@ struct DeviceMoeGemmBlockScale ave_time = total_time; } else{ - status = hipModuleLaunchKernel(kernel_func, + hip_check_error(hipModuleLaunchKernel(kernel_func, gdx, gdy, 1, @@ -259,12 +244,7 @@ struct DeviceMoeGemmBlockScale 0, stream_config.stream_id_, nullptr, - reinterpret_cast(&config)); - if(hipSuccess != status) - { - printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status)); - return; - } + reinterpret_cast(&config))); } };