mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Refine layernorm naming and test code (#497)
* Sync the naming
* Sync the test of layernorm with groupnorm
* Sync the naming
* Minor change for comment and log
* [What] Add saveMean and SaveInvVariance to the interface.
[Why] These can be used to optimize the backward pass.
[ROCm/composable_kernel commit: d4d1147f0a]
This commit is contained in:
@@ -126,6 +126,8 @@ bool profile_groupnorm_impl(int do_verification,
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
y_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
PassThrough{});
|
||||
|
||||
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
@@ -196,7 +198,7 @@ bool profile_groupnorm_impl(int do_verification,
|
||||
|
||||
if(num_kernel == 0)
|
||||
{
|
||||
std::cout << "Error: No kernel is tested" << std::endl;
|
||||
std::cout << "Error: No kernel is applicable" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ template <typename XDataType,
|
||||
typename AccDataType,
|
||||
typename YDataType,
|
||||
index_t Rank>
|
||||
void profile_layernorm_impl(int do_verification,
|
||||
bool profile_layernorm_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
@@ -31,7 +31,7 @@ void profile_layernorm_impl(int do_verification,
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
if(length.size() < 2)
|
||||
return;
|
||||
return false;
|
||||
|
||||
// Assume normalize dimension except for batch (first) dimension
|
||||
std::vector<index_t> reduce_length{length.begin() + 1, length.end()};
|
||||
@@ -52,7 +52,6 @@ void profile_layernorm_impl(int do_verification,
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
// case 0: break;
|
||||
case 0:
|
||||
x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
gamma.GenerateTensorValue(GeneratorTensor_1<GammaDataType>{});
|
||||
@@ -122,6 +121,8 @@ void profile_layernorm_impl(int do_verification,
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
int num_kernel = 0;
|
||||
|
||||
for(auto& inst_ptr : instance_ptrs)
|
||||
{
|
||||
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
|
||||
@@ -135,12 +136,21 @@ void profile_layernorm_impl(int do_verification,
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
y_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
PassThrough{});
|
||||
|
||||
if(!inst_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
|
||||
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
|
||||
++num_kernel;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
|
||||
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
@@ -156,8 +166,9 @@ void profile_layernorm_impl(int do_verification,
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< inst_ptr->GetTypeString() << std::endl;
|
||||
if(time_kernel)
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< inst_ptr->GetTypeString() << std::endl;
|
||||
|
||||
if(avg_time < best_avg_time)
|
||||
{
|
||||
@@ -184,20 +195,32 @@ void profile_layernorm_impl(int do_verification,
|
||||
{
|
||||
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
|
||||
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "pass" << std::endl;
|
||||
if(time_kernel)
|
||||
std::cout << "pass" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LogRange(std::cout << "length = ", length, ",") << ", ";
|
||||
LogRange(std::cout << "stride = ", strideXY, ",") << ", ";
|
||||
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
|
||||
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_instance_name << std::endl;
|
||||
if(time_kernel)
|
||||
{
|
||||
LogRange(std::cout << "length = ", length, ",") << ", ";
|
||||
LogRange(std::cout << "stride = ", strideXY, ",") << ", ";
|
||||
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
|
||||
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_instance_name << std::endl;
|
||||
}
|
||||
|
||||
if(num_kernel == 0)
|
||||
{
|
||||
std::cout << "Error: No kernel is applicable" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
||||
Reference in New Issue
Block a user