mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Update v1_128x128x128 to 2x2 instead of 4x1
This commit is contained in:
Binary file not shown.
@@ -118,7 +118,7 @@ static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 1;
|
||||
@@ -161,11 +161,11 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
32, 32,
|
||||
4, 1,
|
||||
2, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, false, false, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -180,11 +180,11 @@ int main(int argc, char* argv[])
|
||||
// experts = 8
|
||||
// per expert:
|
||||
|
||||
constexpr ck::index_t valid_tile_num = 52;
|
||||
constexpr ck::index_t sorted_tile_num = valid_tile_num + 3;
|
||||
constexpr ck::index_t valid_tile_num = 13; //13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock
|
||||
constexpr ck::index_t sorted_tile_num = valid_tile_num;// + 3;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
#if 0
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
@@ -249,14 +249,22 @@ int main(int argc, char* argv[])
|
||||
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
//int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
int eids[sorted_tile_num]{};
|
||||
int e_select = 0;
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
if (i < valid_tile_num){
|
||||
eids[i] = std::rand() % experts;
|
||||
eids[i] = e_select;
|
||||
//std::rand() % experts;
|
||||
}
|
||||
else{
|
||||
eids[i] = 3;
|
||||
}
|
||||
if (i > ((e_select + 1) * (sorted_tile_num / experts))){
|
||||
e_select++;
|
||||
if (e_select >= experts){
|
||||
e_select = experts - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
@@ -319,9 +327,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
@@ -445,7 +453,7 @@ int main(int argc, char* argv[])
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << device_op.GetTypeString() << std::endl;
|
||||
<< " GB/s.\n" << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
@@ -540,10 +548,14 @@ int main(int argc, char* argv[])
|
||||
#endif
|
||||
// e_t_n_device_result.savetxt("out.txt");
|
||||
// e_t_n_host_result.savetxt("ref.txt");
|
||||
return ck::utils::check_err(
|
||||
auto status = ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
if (status == 0){
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -190,18 +190,8 @@ struct DeviceMoeGemmBlockScale
|
||||
#endif
|
||||
hipModule_t module;
|
||||
hipFunction_t kernel_func;
|
||||
auto status = hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str());
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
printf("Failed to load module (%s): %s.\n", hsa.c_str(), hipGetErrorString(status));
|
||||
return;
|
||||
}
|
||||
status = hipModuleGetFunction(&kernel_func, module, kernel_name.c_str());
|
||||
if(hipSuccess != status)
|
||||
{
|
||||
printf("Failed to get function (%s): %s.\n", kernel_name.c_str(), hipGetErrorString(status));
|
||||
return;
|
||||
}
|
||||
hip_check_error(hipModuleLoad(&module, (std::string(MOE_STAGE2_ASM_DIR) + hsa).c_str()));
|
||||
hip_check_error(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));
|
||||
auto arg_size = sizeof(arg);
|
||||
auto arg_ptr = arg;
|
||||
// // RunKernel(impl_ptr);
|
||||
@@ -221,7 +211,7 @@ struct DeviceMoeGemmBlockScale
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
|
||||
status = hipModuleLaunchKernel(kernel_func,
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
gdx,
|
||||
gdy,
|
||||
1,
|
||||
@@ -231,12 +221,7 @@ struct DeviceMoeGemmBlockScale
|
||||
0,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config));
|
||||
if(hipSuccess != status)
|
||||
{
|
||||
printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status));
|
||||
return;
|
||||
}
|
||||
reinterpret_cast<void**>(&config)));
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
|
||||
@@ -249,7 +234,7 @@ struct DeviceMoeGemmBlockScale
|
||||
ave_time = total_time;
|
||||
}
|
||||
else{
|
||||
status = hipModuleLaunchKernel(kernel_func,
|
||||
hip_check_error(hipModuleLaunchKernel(kernel_func,
|
||||
gdx,
|
||||
gdy,
|
||||
1,
|
||||
@@ -259,12 +244,7 @@ struct DeviceMoeGemmBlockScale
|
||||
0,
|
||||
stream_config.stream_id_,
|
||||
nullptr,
|
||||
reinterpret_cast<void**>(&config));
|
||||
if(hipSuccess != status)
|
||||
{
|
||||
printf("Failed to Luach Kernel: %s\n", hipGetErrorString(status));
|
||||
return;
|
||||
}
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user