Update v1_128x128x128 to 2x2 instead of 4x1

This commit is contained in:
OscarXu
2025-05-07 14:24:30 +08:00
parent e7fe8587f6
commit c989bbe3aa
3 changed files with 29 additions and 37 deletions

View File

@@ -118,7 +118,7 @@ static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
#if 1
#if 0
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 1;
@@ -161,11 +161,11 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso
MPerBlock, 128, 128,
16, 16,
32, 32,
4, 1,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, false, false, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
#endif
// clang-format on
@@ -180,11 +180,11 @@ int main(int argc, char* argv[])
// experts = 8
// per expert:
constexpr ck::index_t valid_tile_num = 52;
constexpr ck::index_t sorted_tile_num = valid_tile_num + 3;
constexpr ck::index_t valid_tile_num = 13; //13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock
constexpr ck::index_t sorted_tile_num = valid_tile_num;// + 3;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
#if 0
#if 1
// GEMM shape
ck::index_t N = 6144;
ck::index_t K = 4096;
@@ -249,14 +249,22 @@ int main(int argc, char* argv[])
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
//int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
int e_select = 0;
for(int i = 0; i < sorted_tile_num; i++)
{
if (i < valid_tile_num){
eids[i] = std::rand() % experts;
eids[i] = e_select;
//std::rand() % experts;
}
else{
eids[i] = 3;
}
if (i > ((e_select + 1) * (sorted_tile_num / experts))){
e_select++;
if (e_select >= experts){
e_select = experts - 1;
}
}
}
// int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -319,9 +327,9 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
@@ -445,7 +453,7 @@ int main(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << device_op.GetTypeString() << std::endl;
<< " GB/s.\n" << device_op.GetTypeString() << std::endl;
}
if(do_verification)
@@ -540,10 +548,14 @@ int main(int argc, char* argv[])
#endif
// e_t_n_device_result.savetxt("out.txt");
// e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err(
auto status = ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
if (status == 0){
printf("Validation Pass.\n");
}
return status;
}
return 0;

View File

@@ -190,18 +190,8 @@ struct DeviceMoeGemmBlockScale
#endif
hipModule_t module;
hipFunction_t kernel_func;
auto status = hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str());
if(status != hipSuccess)
{
printf("Failed to load module (%s): %s.\n", hsa.c_str(), hipGetErrorString(status));
return;
}
status = hipModuleGetFunction(&kernel_func, module, kernel_name.c_str());
if(hipSuccess != status)
{
printf("Failed to get function (%s): %s.\n", kernel_name.c_str(), hipGetErrorString(status));
return;
}
hip_check_error(hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str()));
hip_check_error(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));
auto arg_size = sizeof(arg);
auto arg_ptr = arg;
// // RunKernel(impl_ptr);
@@ -221,7 +211,7 @@ struct DeviceMoeGemmBlockScale
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
status = hipModuleLaunchKernel(kernel_func,
hip_check_error(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
1,
@@ -231,12 +221,7 @@ struct DeviceMoeGemmBlockScale
0,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config));
if(hipSuccess != status)
{
printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status));
return;
}
reinterpret_cast<void**>(&config)));
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
@@ -249,7 +234,7 @@ struct DeviceMoeGemmBlockScale
ave_time = total_time;
}
else{
status = hipModuleLaunchKernel(kernel_func,
hip_check_error(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
1,
@@ -259,12 +244,7 @@ struct DeviceMoeGemmBlockScale
0,
stream_config.stream_id_,
nullptr,
reinterpret_cast<void**>(&config));
if(hipSuccess != status)
{
printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status));
return;
}
reinterpret_cast<void**>(&config)));
}
};