Refine layernorm naming and test code (#497)

* Sync the naming

* Sync the test of layernorm with groupnorm

* Sync the naming

* Minor change for comment and log

* [What] Add saveMean and SaveInvVariance in the interface.
[Why] These can optimize the backward pass

[ROCm/composable_kernel commit: d4d1147f0a]
This commit is contained in:
rocking5566
2022-11-03 06:57:28 +08:00
committed by GitHub
parent 4c1d1a8e59
commit fe367bc917
15 changed files with 207 additions and 311 deletions

View File

@@ -126,6 +126,8 @@ bool profile_groupnorm_impl(int do_verification,
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
@@ -196,7 +198,7 @@ bool profile_groupnorm_impl(int do_verification,
if(num_kernel == 0)
{
std::cout << "Error: No kernel is tested" << std::endl;
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}

View File

@@ -22,7 +22,7 @@ template <typename XDataType,
typename AccDataType,
typename YDataType,
index_t Rank>
void profile_layernorm_impl(int do_verification,
bool profile_layernorm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
@@ -31,7 +31,7 @@ void profile_layernorm_impl(int do_verification,
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
if(length.size() < 2)
return;
return false;
// Assume normalize dimension except for batch (first) dimension
std::vector<index_t> reduce_length{length.begin() + 1, length.end()};
@@ -52,7 +52,6 @@ void profile_layernorm_impl(int do_verification,
switch(init_method)
{
// case 0: break;
case 0:
x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
gamma.GenerateTensorValue(GeneratorTensor_1<GammaDataType>{});
@@ -122,6 +121,8 @@ void profile_layernorm_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
int num_kernel = 0;
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
@@ -135,12 +136,21 @@ void profile_layernorm_impl(int do_verification,
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
if(!inst_ptr->IsSupportedArgument(argument_ptr.get()))
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
++num_kernel;
}
else
{
if(time_kernel)
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
}
continue;
}
@@ -156,8 +166,9 @@ void profile_layernorm_impl(int do_verification,
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time)
{
@@ -184,20 +195,32 @@ void profile_layernorm_impl(int do_verification,
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return;
return false;
}
else
{
std::cout << "pass" << std::endl;
if(time_kernel)
std::cout << "pass" << std::endl;
}
}
}
LogRange(std::cout << "length = ", length, ",") << ", ";
LogRange(std::cout << "stride = ", strideXY, ",") << ", ";
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl;
if(time_kernel)
{
LogRange(std::cout << "length = ", length, ",") << ", ";
LogRange(std::cout << "stride = ", strideXY, ",") << ", ";
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
}
} // namespace profiler