mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
update layernorm (#1570)
* port layernorm * change warp_welford.hpp * Update warpshuffle * 1. Add save mean and save std back 2. Move construction of tensor_view and tile_window to operator() * refine welford max count calculation * unify layernorm api * Rename file * Remove save mean and inv std * Revert "refine welford max count calculation" This reverts commit022365802b. * Fix order of parameter * refine welford max count calculation again * Remove fp32 instances * Fix bug of padding * refactor api * Support bf16 * Extract common function * Refine arg of operator() * Add kMThreadPerBlock to template parameter * clang format * Refine variable name * Refine file name * remove redundant line * refactor layernorm2d pipeline and add block-per-block utility * fix name * rename more * add more block-per-tile instance * remove duplicated define * update instance for 2048, 1024 case * support up to 2048 now * opt loading * add n1536 * Add two pass pipeline * format * Fix incorrect type * parallel compilation * Use smaller N * fix 2p pass * Support Repeat_M in distribution * Refine nameing * Add reduce example --------- Co-authored-by: letaoqin <letaoqin@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: carlushuang <carlus.huang@amd.com> [ROCm/composable_kernel commit:0394f8a713]
This commit is contained in:
223
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
Executable file → Normal file
223
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
Executable file → Normal file
@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
case 0: break;
|
||||
case 1:
|
||||
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
|
||||
default:
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
default:
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
|
||||
break;
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_re(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
// Intermediate Value For E Real and Img
|
||||
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem e_device_buf_re1(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img1(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
|
||||
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
|
||||
@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
// set zero for intermediate values
|
||||
e_device_buf_re1.SetZero();
|
||||
e_device_buf_img1.SetZero();
|
||||
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{alpha, beta};
|
||||
@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
// device operation
|
||||
// For real Intermediate Value re_1
|
||||
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
|
||||
e_device_buf_re1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument_re1 =
|
||||
op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
|
||||
e_device_buf_re1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_re1))
|
||||
{
|
||||
@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
alpha = -1.f;
|
||||
beta = 1.f;
|
||||
|
||||
@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
// For real Intermediate Value re_2
|
||||
// auto op = DeviceOpInstance{};
|
||||
// auto invoker = op.MakeInvoker();
|
||||
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
|
||||
e_device_buf_re.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
auto argument_re2 =
|
||||
op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
|
||||
e_device_buf_re.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_re2))
|
||||
{
|
||||
@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
b_element_op = BElementOp{};
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
|
||||
e_device_buf_img1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
auto argument_img1 =
|
||||
op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
|
||||
e_device_buf_img1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_img1))
|
||||
{
|
||||
@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
|
||||
e_device_buf_img.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
|
||||
auto argument_img2 =
|
||||
op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
|
||||
e_device_buf_img.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_img2))
|
||||
{
|
||||
@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
ck::index_t M =
|
||||
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
|
||||
|
||||
@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
|
||||
|
||||
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
|
||||
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
|
||||
|
||||
auto isRealOk = 0;
|
||||
auto isImgOk = 0;
|
||||
auto isImgOk = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument_re =
|
||||
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
|
||||
auto ref_argument_re = ref_op.MakeArgument(
|
||||
a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_re);
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
|
||||
@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
alpha = 1.f;
|
||||
beta = -1.f;
|
||||
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
auto ref_argument_re1 =
|
||||
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
|
||||
auto ref_argument_re1 = ref_op.MakeArgument(
|
||||
a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_re1);
|
||||
|
||||
@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
|
||||
|
||||
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
// Img Part Verification
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
auto ref_argument_img =
|
||||
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
|
||||
|
||||
auto ref_argument_img = ref_op.MakeArgument(
|
||||
a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_img);
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
|
||||
@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_argument_img1 =
|
||||
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
|
||||
|
||||
auto ref_argument_img1 = ref_op.MakeArgument(
|
||||
a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_img1);
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
|
||||
@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
return (isRealOk && isImgOk);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,21 @@
|
||||
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
|
||||
target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
|
||||
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
|
||||
@@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_layernorm2d_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
|
||||
@@ -20,4 +19,4 @@ args:
|
||||
-e epsilon (default:1e-5)
|
||||
-v cpu validation or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
```
|
||||
```
|
||||
|
||||
155
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
Normal file
155
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
Normal file
@@ -0,0 +1,155 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = layernorm2d_fwd_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveMeanInvStd_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename data_type>
|
||||
float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
|
||||
layernorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
#if 1
|
||||
float r = -1;
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
if(a.n <= 64) {
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 128) {
|
||||
if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 256) {
|
||||
if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 512) {
|
||||
if (a.n % 8 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 768) {
|
||||
if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 1024) {
|
||||
if (a.n % 8 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 1536) {
|
||||
if (a.n % 8 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 2048) {
|
||||
if (a.n % 8 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 3072) {
|
||||
if (a.n % 8 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 4096) {
|
||||
if (a.n % 8 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n > 4096) {
|
||||
if (a.n % 8 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a);
|
||||
else
|
||||
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
|
||||
}
|
||||
return r;
|
||||
#else
|
||||
return layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a);
|
||||
#endif
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
layernorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
|
||||
float r = -1;
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
return layernorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
|
||||
}
|
||||
else if(t.data_type.compare("bf16") == 0)
|
||||
{
|
||||
return layernorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
|
||||
}
|
||||
if(r < 0)
|
||||
throw std::runtime_error("Without supported instances!");
|
||||
|
||||
return r;
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
#if 0
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
#if 0
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,67 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = layernorm2d_fwd_args;
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = layernorm2d_fwd_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveMeanInvStd_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename Traits_>
|
||||
float layernorm2d_fwd_(const S& s, A a)
|
||||
{
|
||||
using DataType = typename Traits_::DataType;
|
||||
|
||||
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
|
||||
typename LayerNormTypeConfig<DataType>::XDataType,
|
||||
typename LayerNormTypeConfig<DataType>::GammaDataType,
|
||||
typename LayerNormTypeConfig<DataType>::BetaDataType,
|
||||
typename LayerNormTypeConfig<DataType>::ComputeDataType,
|
||||
typename LayerNormTypeConfig<DataType>::YDataType,
|
||||
typename LayerNormTypeConfig<DataType>::MeanDataType,
|
||||
typename LayerNormTypeConfig<DataType>::InvStdDataType,
|
||||
typename Traits_::Shape,
|
||||
Traits_::kPadN,
|
||||
Traits_::kSaveMeanInvStd,
|
||||
Traits_::kTwoPass>;
|
||||
|
||||
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -2,161 +2,120 @@
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// Host API implementation
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
layernorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
#else
|
||||
using MeanDataType = ck_tile::null_type;
|
||||
using InvStdDataType = ck_tile::null_type;
|
||||
#endif
|
||||
using ComputeDataType = float;
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
using thread_tile = ck_tile::sequence<4, 4>;
|
||||
using warp_tile = ck_tile::sequence<8, 128>;
|
||||
using block_tile = ck_tile::sequence<32, 128>;
|
||||
|
||||
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
|
||||
|
||||
using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType,
|
||||
Shape,
|
||||
true,
|
||||
true>;
|
||||
|
||||
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(
|
||||
a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a.M);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
return 0;
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision");
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
template <typename DataType, bool SaveMeanVar>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
#else
|
||||
using MeanDataType = ck_tile::null_type;
|
||||
using InvStdDataType = ck_tile::null_type;
|
||||
#endif
|
||||
using ComputeDataType = float;
|
||||
assert(stride >= n);
|
||||
|
||||
using TypeConfig = LayerNormTypeConfig<DataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using BetaDataType = typename TypeConfig::BetaDataType;
|
||||
|
||||
using MeanDataType =
|
||||
std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
|
||||
using InvStdDataType =
|
||||
std::conditional_t<SaveMeanVar, typename TypeConfig::InvStdDataType, ck_tile::null_type>;
|
||||
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({M, N});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({N});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({N});
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({n});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({M, N});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({M, N});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
|
||||
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
|
||||
#endif
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
|
||||
#endif
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
beta_buf.ToDevice(beta_host.data());
|
||||
|
||||
layernorm2d_fwd_traits traits{data_type};
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{data_type, SaveMeanVar};
|
||||
|
||||
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
beta_buf.GetDeviceBuffer(),
|
||||
y_buf.GetDeviceBuffer(),
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
mean_buf.GetDeviceBuffer(),
|
||||
invStd_buf.GetDeviceBuffer(),
|
||||
#else
|
||||
nullptr,
|
||||
nullptr,
|
||||
#endif
|
||||
epsilon,
|
||||
M,
|
||||
N};
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
|
||||
float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});
|
||||
float ave_time = layernorm2d_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
|
||||
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
|
||||
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
|
||||
<< std::flush;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
@@ -174,20 +133,59 @@ int main(int argc, char* argv[])
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
pass = ck_tile::check_err(y_host_dev, y_host_ref);
|
||||
auto [rtol, atol] = get_elimit<DataType>();
|
||||
if(stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
|
||||
y_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
|
||||
y_host_ref.begin() + i_r * stride + n);
|
||||
pass &= ck_tile::check_err(y_host_dev_row,
|
||||
y_host_ref_row,
|
||||
std::string("OUT[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
mean_buf.FromDevice(mean_host_dev.data());
|
||||
pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
|
||||
|
||||
invStd_buf.FromDevice(invStd_host_dev.data());
|
||||
pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
|
||||
#endif
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
std::cout << std::endl << std::flush;
|
||||
|
||||
return !pass;
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
int save_mv = arg_parser.get_int("save_mv");
|
||||
if(data_type == "fp16" && save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp16" && !save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16" && save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16" && !save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
|
||||
@@ -8,23 +8,114 @@
|
||||
#include "ck_tile/ops/layernorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
struct LayerNormTypeConfig;
|
||||
|
||||
template <>
|
||||
struct LayerNormTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LayerNormTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using BetaDataType = ck_tile::bf16_t;
|
||||
using MeanDataType = ck_tile::bf16_t;
|
||||
using InvStdDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_>
|
||||
struct layernorm2d_fwd_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct layernorm2d_fwd_traits
|
||||
{
|
||||
std::string data_type;
|
||||
bool save_mean_var;
|
||||
};
|
||||
|
||||
struct layernorm2d_fwd_args
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_gamma;
|
||||
const void* p_beta;
|
||||
void* p_y;
|
||||
void* p_mean;
|
||||
void* p_invStd;
|
||||
float epsilon;
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
};
|
||||
|
||||
// host API
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
38
example/ck_tile/02_layernorm2d/script/perf_test.sh
Executable file
38
example/ck_tile/02_layernorm2d/script/perf_test.sh
Executable file
@@ -0,0 +1,38 @@
|
||||
|
||||
# run from top of ck folder
|
||||
EXE=build/bin/tile_example_layernorm2d_fwd
|
||||
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
31
example/ck_tile/02_layernorm2d/script/smoke_test.sh
Executable file
31
example/ck_tile/02_layernorm2d/script/smoke_test.sh
Executable file
@@ -0,0 +1,31 @@
|
||||
#!/bin/sh
|
||||
# call from top of CK folder
|
||||
EXE=./build/bin/tile_example_layernorm2d_fwd
|
||||
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
$EXE -prec=$pr_i -m=99 -n=13
|
||||
$EXE -prec=$pr_i -m=17 -n=16
|
||||
$EXE -prec=$pr_i -m=1 -n=100
|
||||
$EXE -prec=$pr_i -m=4 -n=128
|
||||
$EXE -prec=$pr_i -m=80 -n=127
|
||||
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
|
||||
$EXE -prec=$pr_i -m=7 -n=599
|
||||
$EXE -prec=$pr_i -m=19 -n=512
|
||||
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
|
||||
$EXE -prec=$pr_i -m=11 -n=510
|
||||
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
|
||||
$EXE -prec=$pr_i -m=91 -n=636
|
||||
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
|
||||
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
|
||||
$EXE -prec=$pr_i -m=31 -n=1024
|
||||
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec=$pr_i -m=8 -n=1501
|
||||
$EXE -prec=$pr_i -m=3 -n=1826
|
||||
$EXE -prec=$pr_i -m=5 -n=2040
|
||||
$EXE -prec=$pr_i -m=7 -n=2734
|
||||
$EXE -prec=$pr_i -m=1 -n=3182
|
||||
$EXE -prec=$pr_i -m=9 -n=4096
|
||||
$EXE -prec=$pr_i -m=3 -n=8192
|
||||
$EXE -prec=$pr_i -m=1 -n=10547
|
||||
$EXE -prec=$pr_i -m=3 -n=17134
|
||||
done
|
||||
19
example/ck_tile/05_reduce/CMakeLists.txt
Normal file
19
example/ck_tile/05_reduce/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
set(EXAMPLE_REDUCE "tile_example_reduce")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_REDUCE}")
|
||||
|
||||
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
|
||||
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
110
example/ck_tile/05_reduce/reduce.cpp
Normal file
110
example/ck_tile/05_reduce/reduce.cpp
Normal file
@@ -0,0 +1,110 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "reduce.hpp"
|
||||
#include <cstring>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using ADataType = DataType;
|
||||
using AccDataType = float;
|
||||
using BDataType = DataType;
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n});
|
||||
ck_tile::HostTensor<BDataType> b_host_ref({m});
|
||||
ck_tile::HostTensor<BDataType> b_host_dev({m});
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Kernel = ck_tile::Reduce<ADataType,
|
||||
AccDataType,
|
||||
BDataType,
|
||||
kBlockSize,
|
||||
BlockWarps,
|
||||
BlockTile,
|
||||
WarpTile,
|
||||
ThreadTile>;
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * m * n + sizeof(BDataType) * m;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_reduce<ADataType, AccDataType, BDataType>(a_host, b_host_ref);
|
||||
b_buf.FromDevice(b_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(b_host_dev, b_host_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
118
example/ck_tile/05_reduce/reduce.hpp
Normal file
118
example/ck_tile/05_reduce/reduce.hpp
Normal file
@@ -0,0 +1,118 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename AccDataType,
|
||||
typename BDataType,
|
||||
index_t kBlockSize,
|
||||
typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct Reduce
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Thread_M = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t Thread_N = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Thread_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Thread_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
__device__ static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M, ThreadPerWarp_M, Thread_M>,
|
||||
sequence<Repeat_N, WarpPerBlock_N, ThreadPerWarp_N, Thread_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
__device__ void operator()(const ADataType* p_a, BDataType* p_b, index_t M, index_t N) const
|
||||
{
|
||||
const auto a_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, N), make_tuple(N, 1), number<Thread_N>{}, number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * Block_M;
|
||||
|
||||
// A window
|
||||
auto a_block_window = make_tile_window(a_m_n,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
{iM, 0},
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
|
||||
const ADataType reduce_init_value = 0;
|
||||
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
// Acc tile
|
||||
// TODO: support cross warp reduction
|
||||
auto acc_block_tensor = decltype(block_tile_reduce<AccDataType>(
|
||||
load_tile(a_block_window), reduce_dims, f_reduce, reduce_init_value)){};
|
||||
|
||||
// init Acc tile
|
||||
tile_elementwise_inout(
|
||||
[&](auto& acc) { acc = type_convert<AccDataType>(reduce_init_value); },
|
||||
acc_block_tensor);
|
||||
|
||||
// loop
|
||||
index_t iN = 0;
|
||||
|
||||
do
|
||||
{
|
||||
const auto a_block_tensor = load_tile(a_block_window);
|
||||
|
||||
// FIXME: support cross warp reduction
|
||||
block_tile_reduce(acc_block_tensor, a_block_tensor, reduce_dims, f_reduce);
|
||||
|
||||
move_tile_window(a_block_window, {0, Block_N});
|
||||
|
||||
iN += Block_N;
|
||||
|
||||
} while(iN < N);
|
||||
|
||||
// FIXME: support cross warp reduction
|
||||
block_tile_reduce_sync(acc_block_tensor, f_reduce);
|
||||
|
||||
// convert acc_block_tensor to b_block_tensor
|
||||
const auto b_block_tensor = tile_elementwise_in(
|
||||
[](const auto& acc) { return type_convert<BDataType>(acc); }, acc_block_tensor);
|
||||
|
||||
// B
|
||||
const auto b_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_b, make_tuple(M), number<32>{});
|
||||
|
||||
// B window
|
||||
auto b_block_window = make_tile_window(b_m, make_tuple(number<Block_M>{}), {iM});
|
||||
|
||||
// store B tile
|
||||
store_tile(b_block_window, b_block_tensor);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,3 +6,4 @@ add_subdirectory(01_fmha)
|
||||
add_subdirectory(02_layernorm2d)
|
||||
add_subdirectory(03_gemm)
|
||||
add_subdirectory(04_img2col)
|
||||
add_subdirectory(05_reduce)
|
||||
|
||||
@@ -52,6 +52,7 @@
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
|
||||
@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
{
|
||||
#if 0
|
||||
return __shfl(v_local, src_lane);
|
||||
#elif 1
|
||||
if constexpr(sizeof(int32_t) > sizeof(T))
|
||||
{
|
||||
union packet
|
||||
{
|
||||
int32_t x;
|
||||
T v;
|
||||
};
|
||||
packet p;
|
||||
p.v = v_local;
|
||||
packet p_remote;
|
||||
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
|
||||
|
||||
return p_remote.v;
|
||||
}
|
||||
else if constexpr(sizeof(int32_t) == sizeof(T))
|
||||
{
|
||||
const int32_t v_remote_tmp =
|
||||
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
|
||||
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
|
||||
using vector_type = thread_buffer<int32_t, elm>;
|
||||
auto vs = bit_cast<vector_type>(v_local);
|
||||
auto vs_remote = vector_type{};
|
||||
static_for<0, elm, 1>{}([&](auto i_e) {
|
||||
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
|
||||
vs_remote(i_e) = tmp;
|
||||
});
|
||||
return bit_cast<T>(vs_remote);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -32,11 +32,13 @@
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
#define CK_TILE_HOST_DEVICE inline __host__ __device__
|
||||
#define CK_TILE_DEVICE_EXTERN __device__
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
|
||||
#else
|
||||
#define CK_TILE_HOST inline
|
||||
#define CK_TILE_DEVICE inline
|
||||
#define CK_TILE_HOST_DEVICE inline
|
||||
#define CK_TILE_DEVICE_EXTERN
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
|
||||
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
template <typename, typename, typename, index_t>
|
||||
struct reverse_slice_sequence_impl;
|
||||
|
||||
template <index_t x,
|
||||
index_t... xs,
|
||||
index_t m,
|
||||
index_t... ms,
|
||||
index_t id,
|
||||
index_t... ids,
|
||||
index_t SliceSize>
|
||||
struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
sequence<m, ms...>,
|
||||
sequence<id, ids...>,
|
||||
SliceSize>
|
||||
{
|
||||
using old_scan =
|
||||
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
|
||||
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths =
|
||||
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
|
||||
using dim_slices =
|
||||
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
|
||||
using remaining_slice_sizes = typename sequence_merge<
|
||||
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
|
||||
typename old_scan::remaining_slice_sizes>::type;
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t _split_idx =
|
||||
std::conditional_t<_split_flag, number<id>, number<0>>::value;
|
||||
|
||||
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
|
||||
static constexpr index_t split_idx = std::
|
||||
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
|
||||
};
|
||||
|
||||
template <index_t x, index_t m, index_t id, index_t SliceSize>
|
||||
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
|
||||
{
|
||||
static constexpr auto slice_size = SliceSize;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths = sequence<slice_length>;
|
||||
using dim_slices = sequence<x / slice_length>;
|
||||
using remaining_slice_sizes =
|
||||
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t split_idx =
|
||||
std::conditional_t<split_flag, number<id>, number<0>>::value;
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
|
||||
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
|
||||
//
|
||||
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
|
||||
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
|
||||
//
|
||||
// <4, 2, 1, 4, 2> / 4 ->
|
||||
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
|
||||
// have split slices (right -> left)
|
||||
// or the first index that sliced length is different from the original length
|
||||
// clang-format on
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto reverse_slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
using sliced_type =
|
||||
impl::reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
|
||||
SliceSize>;
|
||||
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
|
||||
"can not evenly divide this sequence, please check");
|
||||
return make_tuple(typename sliced_type::dim_lengths{},
|
||||
typename sliced_type::dim_slices{},
|
||||
number<sliced_type::split_idx>{});
|
||||
}
|
||||
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
constexpr auto r =
|
||||
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
|
||||
return make_tuple(r[number<0>{}].reverse(),
|
||||
r[number<1>{}].reverse(),
|
||||
number<Seq::size() - r[number<2>{}] - 1>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
|
||||
{
|
||||
return concat_tuple(f(x.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// make sure F return at least a tuple
|
||||
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
|
||||
// this function will return
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::embed_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
// By default unroll to the flatten
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
|
||||
|
||||
@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
|
||||
});
|
||||
}
|
||||
|
||||
// this function used inside span loop over
|
||||
template <typename YLengths, index_t XUnpacks>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
|
||||
{
|
||||
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
|
||||
constexpr auto y_packs = number<XUnpacks>{};
|
||||
static_assert(y_size % y_packs == 0);
|
||||
constexpr auto y_slice_size = y_size / y_packs;
|
||||
|
||||
constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
|
||||
constexpr auto unpacks = slice_info[number<1>{}];
|
||||
return unpacks;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
|
||||
});
|
||||
}
|
||||
|
||||
// unpacked span, this version support span with unpack(multi-arg) functor
|
||||
//
|
||||
template <
|
||||
typename TileDistributedSpan_, // tile_distributed_span<...>
|
||||
typename F, // signature: F(tile_distributed_index<...>)
|
||||
typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
|
||||
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
|
||||
{
|
||||
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
|
||||
|
||||
static_uford<typename DstrSpan::Impl, Unpacks>{}(
|
||||
[&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <typename, typename, typename>
|
||||
struct sweep_tile_impl;
|
||||
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
|
||||
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
|
||||
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
|
||||
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
|
||||
return y_unpacks;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
return u.get_num_of_access() *
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
}
|
||||
template <typename F, typename SpanIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
sweep_tile_uspan(
|
||||
spans[number<I>{}],
|
||||
[&](auto... i_idx) {
|
||||
const auto next_span_idx = embed_tuples(
|
||||
[&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
|
||||
span_idx);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx);
|
||||
},
|
||||
get_y_unpacks());
|
||||
}
|
||||
template <typename F, typename SpanIdx, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
constexpr auto access_stride =
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
constexpr auto curr_i_access = number<i_access / access_stride>{};
|
||||
constexpr auto next_i_access = number<i_access % access_stride>{};
|
||||
u(
|
||||
[&](auto... i_idx) {
|
||||
const auto next_span_idx = embed_tuples(
|
||||
[&](auto si) {
|
||||
return make_tuple(concat_tuple(
|
||||
si, make_tuple(detail::make_tile_distributed_index(i_idx)))...);
|
||||
},
|
||||
span_idx);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx, next_i_access);
|
||||
},
|
||||
curr_i_access);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim>
|
||||
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }
|
||||
template <typename F, typename SpanIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
|
||||
{
|
||||
unpack(f, span_idx);
|
||||
}
|
||||
template <typename F, typename SpanIdx, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
|
||||
{
|
||||
unpack(f, span_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename, typename, typename>
|
||||
struct sweep_tile_impl_0;
|
||||
|
||||
// TODO: support empty tuple to remove this "entry-point" like function
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
|
||||
struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
|
||||
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
|
||||
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
|
||||
return y_unpacks;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
return u.get_num_of_access() *
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
}
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
sweep_tile_uspan(
|
||||
spans[number<I>{}],
|
||||
[&](auto... i_idx) {
|
||||
constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx);
|
||||
},
|
||||
get_y_unpacks());
|
||||
}
|
||||
template <typename F, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
constexpr auto access_stride =
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
constexpr auto curr_i_access = number<i_access / access_stride>{};
|
||||
constexpr auto next_i_access = number<i_access % access_stride>{};
|
||||
u(
|
||||
[&](auto... i_idx) {
|
||||
constexpr auto next_span_idx =
|
||||
make_tuple(make_tuple(detail::make_tile_distributed_index(i_idx))...);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx, next_i_access);
|
||||
},
|
||||
curr_i_access);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
/*
|
||||
* Enhanced sweep-tile utility, can control unpacks along each X-dim
|
||||
* the lambda function argument is the distributed-idx, which can directly
|
||||
* plugged into the distributed tensor as setter/getter
|
||||
*
|
||||
* e.g. below function, y with the type DistributedTensor, r is row scale
|
||||
*
|
||||
* // sweep tile 1 by 1
|
||||
* sweep_tile<DistributedTensor>([&](auto idx) {
|
||||
* constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
* y(idx) = y(idx) * r(row_id);
|
||||
* });
|
||||
*
|
||||
* // sweep tile with 2 pixel from last dim each function call
|
||||
* sweep_tile<DistributedTensor>(
|
||||
* [&](auto idx_0, auto idx_1) {
|
||||
* constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
|
||||
* y(idx_0) = y(idx_0) * r(row_id);
|
||||
* y(idx_1) = y(idx_1) * r(row_id);
|
||||
* },
|
||||
* sequence<1, 2>{});
|
||||
*
|
||||
* // sweep tile with 2x2 pixel each function call
|
||||
* sweep_tile<DistributedTensor>(
|
||||
* [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
|
||||
* constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
|
||||
* constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
|
||||
* y(idx_00) = y(idx_00) * r(row_id0);
|
||||
* y(idx_01) = y(idx_01) * r(row_id0);
|
||||
* y(idx_10) = y(idx_10) * r(row_id1);
|
||||
* y(idx_11) = y(idx_11) * r(row_id1);
|
||||
* },
|
||||
* sequence<2, 2>{});
|
||||
*
|
||||
* TODO: do we need constexpr? lambda function could be non-constexpr
|
||||
*/
|
||||
template <typename DistributedTensor,
|
||||
typename F,
|
||||
typename UnpacksPerXDim =
|
||||
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename F,
|
||||
typename UnpacksPerXDim =
|
||||
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
|
||||
{
|
||||
sweep_tile<DistributedTensor, F, UnpacksPerXDim>(f, UnpacksPerXDim{});
|
||||
}
|
||||
|
||||
/*
|
||||
* construct a sweep tile instance, which support issue the lambda one by one
|
||||
* Note that this struct will hold the lambda functor, but will not hold the distributed tensor
|
||||
* the functionality is the same as sweep_tile()
|
||||
*/
|
||||
template <typename DistributedTensor_,
|
||||
typename F_,
|
||||
typename UnpacksPerXDim_ =
|
||||
typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
|
||||
struct tile_sweeper
|
||||
{
|
||||
using DistributedTensor = remove_cvref_t<DistributedTensor_>;
|
||||
using F = remove_cvref_t<F_>;
|
||||
using UnpacksPerXDim = remove_cvref_t<UnpacksPerXDim_>;
|
||||
|
||||
CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {}
|
||||
CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {})
|
||||
: f(f_)
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto tmp =
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
|
||||
return tmp.get_num_of_access();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void operator()() const
|
||||
{
|
||||
sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
|
||||
}
|
||||
|
||||
template <index_t i_access>
|
||||
CK_TILE_HOST_DEVICE void operator()(number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
|
||||
f, number<i_access>{});
|
||||
}
|
||||
F f;
|
||||
};
|
||||
|
||||
// partial deduction is not allowed
|
||||
// template <typename T, typename F, typename U>
|
||||
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
|
||||
// deduction guide
|
||||
template <typename T,
|
||||
typename F,
|
||||
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -17,6 +17,14 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
|
||||
{
|
||||
return Distribution::_get_partition_index();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// distributed span
|
||||
template <index_t... PartialHsLengths>
|
||||
struct tile_distributed_span
|
||||
@@ -83,6 +91,21 @@ struct tile_distribution
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto _get_partition_index()
|
||||
{
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
return array<index_t, 1>{get_lane_id()};
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
return array<index_t, 2>{get_warp_id(), get_lane_id()};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
|
||||
{
|
||||
#if 0
|
||||
@@ -149,6 +172,16 @@ struct tile_distribution
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename PartitionIndex = decltype(_get_partition_index())>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
|
||||
{
|
||||
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
|
||||
const auto window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
|
||||
return window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
|
||||
{
|
||||
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
|
||||
@@ -421,6 +454,7 @@ struct tile_distribution_detail
|
||||
|
||||
} // namespace detail
|
||||
|
||||
#if 0
|
||||
// this returns a constexpr tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
|
||||
@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
|
||||
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
|
||||
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
|
||||
}
|
||||
#endif
|
||||
|
||||
// this returns a static tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
|
||||
//***********************************************************************************
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
|
||||
{
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
|
||||
|
||||
if constexpr(Distribution::NDimP == 1)
|
||||
{
|
||||
return array<index_t, 1>{get_lane_id()};
|
||||
}
|
||||
else if constexpr(Distribution::NDimP == 2)
|
||||
{
|
||||
return array<index_t, 2>{get_warp_id(), get_lane_id()};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename, typename, typename, index_t>
|
||||
struct reverse_slice_sequence_impl;
|
||||
|
||||
template <index_t x,
|
||||
index_t... xs,
|
||||
index_t m,
|
||||
index_t... ms,
|
||||
index_t id,
|
||||
index_t... ids,
|
||||
index_t SliceSize>
|
||||
struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
sequence<m, ms...>,
|
||||
sequence<id, ids...>,
|
||||
SliceSize>
|
||||
{
|
||||
using old_scan =
|
||||
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
|
||||
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths =
|
||||
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
|
||||
using dim_slices =
|
||||
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
|
||||
using remaining_slice_sizes = typename sequence_merge<
|
||||
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
|
||||
typename old_scan::remaining_slice_sizes>::type;
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t _split_idx =
|
||||
std::conditional_t<_split_flag, number<id>, number<0>>::value;
|
||||
|
||||
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
|
||||
static constexpr index_t split_idx = std::
|
||||
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
|
||||
};
|
||||
|
||||
template <index_t x, index_t m, index_t id, index_t SliceSize>
|
||||
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
|
||||
{
|
||||
static constexpr auto slice_size = SliceSize;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths = sequence<slice_length>;
|
||||
using dim_slices = sequence<x / slice_length>;
|
||||
using remaining_slice_sizes =
|
||||
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t split_idx =
|
||||
std::conditional_t<split_flag, number<id>, number<0>>::value;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
|
||||
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
|
||||
//
|
||||
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
|
||||
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
|
||||
//
|
||||
// <4, 2, 1, 4, 2> / 4 ->
|
||||
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
|
||||
// have split slices (right -> left)
|
||||
// or the first index that sliced length is different from the original length
|
||||
// clang-format on
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto reverse_slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
using sliced_type =
|
||||
reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
|
||||
SliceSize>;
|
||||
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
|
||||
"can not evenly divide this sequence, please check");
|
||||
return make_tuple(typename sliced_type::dim_lengths{},
|
||||
typename sliced_type::dim_slices{},
|
||||
number<sliced_type::split_idx>{});
|
||||
}
|
||||
|
||||
//
|
||||
// slice tensor from x_dim, result in split in y_dim, not p_dim.
|
||||
// We don't support slice cross p_dim (aka, slice different threads)
|
||||
|
||||
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// This file should not be included inside tuple.hpp!
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
template <class F, class CurrentUnpackIds>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
|
||||
{
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_one_shot_impl
|
||||
{
|
||||
template <class F, class CurrentUnpackIds, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
|
||||
{
|
||||
constexpr auto r_lens_stride =
|
||||
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
|
||||
constexpr auto r_upks_stride =
|
||||
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
|
||||
|
||||
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
|
||||
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack, number<current_acc % current_stride>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// TODO: we may unify static_ford/static_uford in the future
|
||||
//
|
||||
// loop over nd space(sequence) with packs
|
||||
// you must make sure the function passed in has same number of argument
|
||||
//
|
||||
// e.g.
|
||||
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
|
||||
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
|
||||
//
|
||||
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
|
||||
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
|
||||
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
|
||||
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
|
||||
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
|
||||
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
|
||||
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
|
||||
// ...
|
||||
template <class Lengths,
|
||||
class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_uford
|
||||
{
|
||||
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
static_for<0, Lengths::size(), 1>{}(
|
||||
[&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
using L_ = decltype(Lengths{} / Unpacks{});
|
||||
|
||||
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id...)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
|
||||
f, make_tuple(sequence<>{}));
|
||||
}
|
||||
|
||||
// this version is friendly for issue function one by one
|
||||
template <class F, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
|
||||
{
|
||||
static_assert(i_access < get_num_of_access());
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_one_shot_impl<decltype(ordered_lengths),
|
||||
decltype(ordered_unpacks),
|
||||
Orders>{}(
|
||||
f, make_tuple(sequence<>{}), number<i_access>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -21,7 +21,7 @@
|
||||
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -5,37 +5,57 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// host side args
|
||||
struct Layernorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_gamma;
|
||||
const void* p_beta;
|
||||
|
||||
void* p_y;
|
||||
void* p_mean;
|
||||
void* p_invStd;
|
||||
|
||||
float epsilon;
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
template <typename Problem_>
|
||||
template <typename Pipeline_>
|
||||
struct Layernorm2dFwd
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = !std::is_same_v<MeanDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvStd = !std::is_same_v<InvStdDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
|
||||
static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
|
||||
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
|
||||
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -52,400 +72,177 @@ struct Layernorm2dFwd
|
||||
|
||||
float epsilon;
|
||||
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
};
|
||||
using Hargs = Layernorm2dFwdHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
void* p_mean,
|
||||
void* p_invStd,
|
||||
float epsilon,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N)
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N};
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_gamma,
|
||||
hargs.p_beta,
|
||||
hargs.p_y,
|
||||
hargs.p_mean,
|
||||
hargs.p_invStd,
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
hargs.stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
|
||||
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>{});
|
||||
return (hargs.m + Block_M - 1) / Block_M;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
// clang-format off
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveMeanInvStd) n += "_mv";
|
||||
if (kTwoPass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>,
|
||||
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
|
||||
{
|
||||
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
|
||||
|
||||
int thread_id_n = get_thread_id() % kNThreadPerBlock;
|
||||
int max_count =
|
||||
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
|
||||
int n_per_block_tail_loop =
|
||||
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
|
||||
|
||||
if(n_per_block_tail_loop > 0)
|
||||
{
|
||||
int thread_max_n = (thread_id_n + 1) * kNPerThread;
|
||||
int delta = thread_max_n - n_per_block_tail_loop;
|
||||
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
|
||||
max_count += kNPerThread - delta;
|
||||
}
|
||||
|
||||
return max_count;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor>
|
||||
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
|
||||
const ComputeDataType epsilon)
|
||||
{
|
||||
// TODO: Investigate fast inverse square root algorithm with epsilon
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
DistributedTensor out_dstr_tensor;
|
||||
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
out_dstr_tensor(i_idx) = type_convert<ComputeDataType>(1.0f) /
|
||||
ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon);
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
// TODO - Optimize tail loop to reduce move_tile_window()
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
|
||||
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
auto var_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
|
||||
clear_tile(mean_compute_block_tensor);
|
||||
clear_tile(var_compute_block_tensor);
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
|
||||
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
move_tile_window(x_block_window, {0, kNPerBlock});
|
||||
}
|
||||
|
||||
// TODO: support cross warp Welford
|
||||
WarpMergeWelford<ComputeDataType, true>{}(
|
||||
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
|
||||
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
|
||||
|
||||
move_tile_window(x_block_window, {0, -kNPerBlock});
|
||||
move_tile_window(gamma_block_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_block_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_block_window, {0, stride_to_right_most_window});
|
||||
|
||||
// Normalization
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
const auto gamma_block_tensor = load_tile(gamma_block_window);
|
||||
const auto beta_block_tensor = load_tile(beta_block_window);
|
||||
|
||||
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
|
||||
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(x_spans[I1], [&](auto idx1) {
|
||||
constexpr auto j_idx = make_tuple(idx1);
|
||||
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
|
||||
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
|
||||
|
||||
sweep_tile_span(x_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto mean = mean_compute_block_tensor[i_idx];
|
||||
const auto inv_std = inv_std_compute_block_tensor[i_idx];
|
||||
|
||||
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
|
||||
auto y = (x - mean) * inv_std * gamma + beta;
|
||||
|
||||
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_block_window, y_block_tensor);
|
||||
|
||||
move_tile_window(x_block_window, {0, -kNPerBlock});
|
||||
move_tile_window(gamma_block_window, {-kNPerBlock});
|
||||
move_tile_window(beta_block_window, {-kNPerBlock});
|
||||
move_tile_window(y_block_window, {0, -kNPerBlock});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
auto var_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
|
||||
clear_tile(mean_compute_block_tensor);
|
||||
clear_tile(var_compute_block_tensor);
|
||||
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
// TODO: support cross warp Welford
|
||||
WarpMergeWelford<ComputeDataType, true>{}(
|
||||
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
|
||||
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// normalize
|
||||
const auto gamma_block_tensor = load_tile(gamma_block_window);
|
||||
const auto beta_block_tensor = load_tile(beta_block_window);
|
||||
|
||||
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
|
||||
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(x_spans[I1], [&](auto idx1) {
|
||||
constexpr auto j_idx = make_tuple(idx1);
|
||||
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
|
||||
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
|
||||
|
||||
sweep_tile_span(x_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto mean = mean_compute_block_tensor[i_idx];
|
||||
const auto inv_std = inv_std_compute_block_tensor[i_idx];
|
||||
|
||||
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
|
||||
auto y = (x - mean) * inv_std * gamma + beta;
|
||||
|
||||
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_block_window, y_block_tensor);
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto iM = get_block_id() * Block_M;
|
||||
|
||||
const auto x_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
|
||||
// check the max count dynamically
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto gamma_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto gamma_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
const auto beta_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto beta_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BetaDataType*>(kargs.p_beta),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
const auto iM = get_block_id() * kMPerBlock;
|
||||
|
||||
constexpr auto xDstr = MakeXBlockTileDistribution();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
|
||||
|
||||
const auto y_m_n = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
auto y_window = [&]() {
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
|
||||
|
||||
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
|
||||
constexpr auto betaDstr = gammaDstr;
|
||||
|
||||
auto gamma_block_window =
|
||||
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
|
||||
|
||||
auto beta_block_window = make_tile_window(
|
||||
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
|
||||
|
||||
auto mean_block_window = [&]() {
|
||||
auto mean_window = [&]() {
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
const auto mean_m = [&]() {
|
||||
const auto mean_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<MeanDataType*>(kargs.p_mean),
|
||||
make_tuple(kargs.M),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
mean_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
return make_tile_window(mean_m, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto inv_std_block_window = [&]() {
|
||||
auto inv_std_window = [&]() {
|
||||
if constexpr(kSaveInvStd)
|
||||
{
|
||||
const auto inv_std_m = [&]() {
|
||||
const auto inv_std_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<InvStdDataType*>(kargs.p_invStd),
|
||||
make_tuple(kargs.M),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
inv_std_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
return make_tile_window(inv_std_m, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
if(kargs.N <= kNPerBlock)
|
||||
OnePassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
else
|
||||
TwoPassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window,
|
||||
gamma_window,
|
||||
beta_window,
|
||||
y_window,
|
||||
mean_window,
|
||||
inv_std_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
smem);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
/*
|
||||
// clang-format off
|
||||
|
||||
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
|
||||
|
||||
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
|
||||
+<----------------------< Repeat_N(2)>--------------------->+
|
||||
| |
|
||||
+<-- <WarpPerBlock_N(2)> -->+
|
||||
Warp_N
|
||||
+--------------+--------------+--------------+--------------+----+----------------+
|
||||
Warp_M | wrap_0 | wrap_1 | | ^ ^
|
||||
+--------------+--------------+ | <WarpPerBlock_M(2)> |
|
||||
| wrap_2 | wrap_3 | | v
|
||||
+--------------+--------------+--------------+--------------+----+ Block_M
|
||||
| | |
|
||||
+ + |
|
||||
| | | v
|
||||
+--------------+--------------+--------------+--------------+ +
|
||||
|
||||
each Warp-tile (e.g 16 thrd per row)
|
||||
|
||||
Vector_N (contiguous pixels each thrd holds along N, or vector size)
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
// clang-format on
|
||||
*/
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
|
||||
index_t BlockSize_ =
|
||||
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
|
||||
struct Layernorm2dShape
|
||||
{
|
||||
// block size
|
||||
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
|
||||
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
|
||||
// repeat of each thread along seq<M, N>
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,34 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename GammaDataType_,
|
||||
typename BetaDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadM_,
|
||||
bool kPadN_>
|
||||
struct BlockLayernorm2dFwdProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
return BlockWelford<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
return BlockWelfordSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
return BlockWelfordCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_welford = BlockWelford<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using mean_var_block_tile =
|
||||
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockWelfordCrossWarpSync<Problem>()
|
||||
.template GetSmemSize<mean_var_block_tile>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
|
||||
struct Layernorm2dFwdPipelineOnePass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr"; // block per row
|
||||
else
|
||||
return "wpr"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
// compute welford each-thread->cross-lane->cross-warp
|
||||
auto [mean, var] = block_welford(x, cur_count, max_count);
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count);
|
||||
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
|
||||
},
|
||||
var);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_window, cast_tile<MeanDataType>(mean));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// layernorm computation
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
sweep_tile(y, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
});
|
||||
store_tile(y_window, y);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename GammaDataType_,
|
||||
typename BetaDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_>
|
||||
struct Layernorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,160 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
|
||||
struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
|
||||
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr"; // block per row
|
||||
else
|
||||
return "wpr"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
|
||||
// Problem::BlockShape
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
|
||||
|
||||
// total number of count assume current iter have no pad(only last iter has pad)
|
||||
constexpr index_t count_per_iter =
|
||||
Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
|
||||
const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
|
||||
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
(num_n_tile_iteration - 1) * count_per_iter +
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
block_welford(x, mean, var, cur_count, max_count);
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
}
|
||||
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count);
|
||||
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
|
||||
},
|
||||
var);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_window, cast_tile<MeanDataType>(mean));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
// x_window.foo();
|
||||
// gamma_window.foo();
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// layernorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&, mean_ = mean](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
});
|
||||
|
||||
store_tile(y_window, y);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(beta_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -1,35 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename ThreadTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename BlockTile> // Sequence<...
|
||||
struct TileLayernorm2dShape
|
||||
{
|
||||
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
|
||||
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
|
||||
|
||||
// TODO - kNNumWarps can only be 1 if we don't support cross warp welford
|
||||
static_assert(kNWarpPerBlock == 1);
|
||||
|
||||
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/welford/block/block_welford.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
|
||||
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
362
include/ck_tile/ops/welford/block/block_welford.hpp
Normal file
362
include/ck_tile/ops/welford/block/block_welford.hpp
Normal file
@@ -0,0 +1,362 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelford
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockWelford() {}
|
||||
|
||||
// [CAUSION] - max_count_ is to deal with the padding problem
|
||||
// max_count_ is depend on caller, eg: naive and splitN welford will have different
|
||||
// calculation of max_count_
|
||||
// -> use block_welford_calculate_max_count to compute
|
||||
template <typename XDistributedTensor_,
|
||||
typename MeanDistributedTensor_,
|
||||
typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
MeanDistributedTensor_& mean_tensor,
|
||||
VarDistributedTensor_& var_tensor,
|
||||
int& cur_count_, // -> prefer init as zero
|
||||
const int& max_count_)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
|
||||
if(cur_count_ < max_count_)
|
||||
{
|
||||
++cur_count_;
|
||||
|
||||
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
|
||||
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
|
||||
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
|
||||
welford_update(
|
||||
mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x, cur_count_);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeMeanVarBlockTile()
|
||||
{
|
||||
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
|
||||
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
constexpr auto dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
XDistributedTensor_::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
reduce_dims));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
|
||||
{
|
||||
auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
|
||||
auto var_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
|
||||
clear_tile(mean_tensor);
|
||||
clear_tile(var_tensor);
|
||||
|
||||
(*this)(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
|
||||
|
||||
return ck_tile::make_tuple(mean_tensor, var_tensor);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
|
||||
{
|
||||
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
// const auto rs_idx =
|
||||
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
const int original_count = count;
|
||||
|
||||
// loop over thread data
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local_mean = mean_tensor.get_thread_buffer()[i];
|
||||
auto v_local_var = var_tensor.get_thread_buffer()[i];
|
||||
auto v_local_count = original_count;
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// xor
|
||||
index_t src_lane =
|
||||
(__lane_id()) ^
|
||||
(number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
|
||||
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
|
||||
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
|
||||
|
||||
// welford merge
|
||||
welford_merge(v_local_mean,
|
||||
v_local_var,
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
mean_tensor.get_thread_buffer()(i) = v_local_mean;
|
||||
var_tensor.get_thread_buffer()(i) = v_local_var;
|
||||
|
||||
count = v_local_count;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordCrossWarpSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
|
||||
template <typename MeanDistributedTensor_>
|
||||
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
|
||||
{
|
||||
constexpr index_t num_reduce_warps = [&]() {
|
||||
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_warp = 0;
|
||||
|
||||
index_t len_ = 1;
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
len_ *= r_length;
|
||||
}
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
return num_reduce_warps;
|
||||
}
|
||||
|
||||
// return in byte
|
||||
template <typename MeanDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
|
||||
// data need to exchange is very small, we just pack mean+var+count -> 4dword
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// we need to store all data from every wave into smem
|
||||
// e.g. 2x2 reduce along N
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | ___> | w01 |
|
||||
// | w2 | w3 | | w23 |
|
||||
//
|
||||
// -> store data from every wave into LDS
|
||||
//
|
||||
//
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
return num_warps * 4 * thread_buf_size * sizeof(float);
|
||||
}
|
||||
|
||||
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
|
||||
VarDistributedTensor_& var_tensor,
|
||||
int& count,
|
||||
void* smem)
|
||||
{
|
||||
using DataType = typename MeanDistributedTensor_::DataType;
|
||||
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
|
||||
// using DstrEncode = typename Dstr::DstrEncode;
|
||||
// using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
// Note: we always pack everything into fp32x4
|
||||
fp32x4_t* smem_ptr = reinterpret_cast<fp32x4_t*>(smem);
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
if constexpr(num_reduce_warps == 1)
|
||||
return;
|
||||
|
||||
// store into smem only for lane-0 within one warp
|
||||
if(lane_id == 0)
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
fp32x4_t local_scratch_;
|
||||
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[2] = bit_cast<float>(count);
|
||||
|
||||
smem_ptr[smem_offset + i * num_warps] = local_scratch_;
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// load from smem. here we let everythread to do compute :)
|
||||
index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
fp32x4_t all_scratch[thread_buf_size * num_reduce_warps];
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
|
||||
all_scratch[i_0 * num_warps + i_1] =
|
||||
smem_ptr[i_0 * num_reduce_warps + local_smem_os + i_1];
|
||||
});
|
||||
});
|
||||
block_sync_lds(); // TODO: we don't need sync here
|
||||
|
||||
// const int original_count = count;
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
// TODO: use descriptor for this
|
||||
auto v_local = all_scratch[i_0 * num_warps];
|
||||
auto v_local_mean = bit_cast<DataType>(v_local[0]);
|
||||
auto v_local_var = bit_cast<DataType>(v_local[1]);
|
||||
auto v_local_count = bit_cast<int>(v_local[2]);
|
||||
|
||||
// further reduce mean/var
|
||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const fp32x4_t v_remote = all_scratch[i_0 * num_warps + i_1];
|
||||
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
|
||||
const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
|
||||
const auto v_remote_count = bit_cast<int>(v_remote[2]);
|
||||
|
||||
welford_merge(v_local_mean,
|
||||
v_local_var,
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count);
|
||||
});
|
||||
|
||||
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
|
||||
var_tensor.get_thread_buffer()(i_0) = v_local_var;
|
||||
|
||||
count = v_local_count;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// compute the max count for a last dim reduce
|
||||
// everything may have vector/repeat, so the max count could be uneven
|
||||
// TODO: specify which dim to compute and proper set the problem
|
||||
// TODO: BlockShape we reuse layernorm_fwd_shape :)
|
||||
template <typename BlockShape>
|
||||
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
|
||||
{
|
||||
#if 0
|
||||
using S = BlockShape;
|
||||
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
|
||||
constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
|
||||
index_t iNLane = get_thread_id() % NThread;
|
||||
index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
|
||||
index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
|
||||
index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
|
||||
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
|
||||
return iN0 * S::Vector_N + iN3;
|
||||
#endif
|
||||
using S_ = BlockShape;
|
||||
constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
|
||||
|
||||
// TODO: we always check vector size, need be evenly devidable by vector-n
|
||||
const index_t element_per_row = row_size / S_::Vector_N;
|
||||
index_t lane_id_n = get_thread_id() % ThreadsPerBlock_N;
|
||||
|
||||
index_t cnt = 0;
|
||||
// TODO: Repeat_N can not be too long, otherwise this is not good
|
||||
static_for<0, S_::Repeat_N, 1>{}([&](auto) {
|
||||
index_t _a = lane_id_n < element_per_row ? 1 : 0;
|
||||
cnt += _a;
|
||||
lane_id_n += ThreadsPerBlock_N;
|
||||
});
|
||||
return cnt * S_::Vector_N;
|
||||
}
|
||||
|
||||
// Note: this function must be called after all the computation
|
||||
template <typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
|
||||
int count)
|
||||
{
|
||||
using DataType = typename VarDistributedTensor_::DataType;
|
||||
tile_elementwise_inout([&count](auto& x) { x = x / type_convert<DataType>(count); },
|
||||
var_tensor);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
18
include/ck_tile/ops/welford/block/block_welford_problem.hpp
Normal file
18
include/ck_tile/ops/welford/block/block_welford_problem.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
|
||||
struct BlockWelfordProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,95 +7,30 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ComputeDataType_, typename XDataType_>
|
||||
struct ThreadWelford
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
// TODO: check nan? maybe no
|
||||
T delta = x - mean;
|
||||
mean += delta / count;
|
||||
T delta2 = x - mean;
|
||||
var += delta * delta2;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void Update(T& mean, T& var, T x)
|
||||
{
|
||||
if(ck_tile::isnan(x))
|
||||
{
|
||||
mean = x;
|
||||
var = x;
|
||||
}
|
||||
else
|
||||
{
|
||||
T delta = x - mean;
|
||||
mean += delta / cur_count_;
|
||||
T delta2 = x - mean;
|
||||
var += delta * delta2;
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static void
|
||||
welford_merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
T count_ = type_convert<T>(count);
|
||||
T count_a_ = type_convert<T>(count_a);
|
||||
T count_b_ = type_convert<T>(count_b);
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
|
||||
|
||||
// [CAUSION] - max_count_ is to deal with the padding problem
|
||||
// max_count_ is depend on caller, eg: naive and splitN welford will have different
|
||||
// calculation of max_count_
|
||||
CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {}
|
||||
|
||||
template <typename XDistributedTensor_,
|
||||
typename MeanDistributedTensor_,
|
||||
typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
MeanDistributedTensor_& mean_tensor,
|
||||
VarDistributedTensor_& var_tensor)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
|
||||
if(cur_count_ < max_count_)
|
||||
{
|
||||
++cur_count_;
|
||||
|
||||
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
|
||||
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
|
||||
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
|
||||
Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor()
|
||||
{
|
||||
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
|
||||
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
constexpr auto dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
XDistributedTensor_::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
reduce_dims));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
|
||||
clear_tile(tensor);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor)
|
||||
{
|
||||
auto mean_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
|
||||
auto var_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
|
||||
|
||||
(*this)(x_tensor, mean_tensor, var_tensor);
|
||||
|
||||
return ck_tile::make_tuple(mean_tensor, var_tensor);
|
||||
}
|
||||
|
||||
int cur_count_;
|
||||
int max_count_;
|
||||
};
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
var_a += var_b + delta * delta * count_a_ * count_b_over_count;
|
||||
count_a = count;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ComputeDataType_, bool BroadcastLane = true, bool GetActualVariance = true>
|
||||
struct WarpMergeWelford
|
||||
{
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static void
|
||||
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
T count_ = type_convert<T>(count);
|
||||
T count_a_ = type_convert<T>(count_a);
|
||||
T count_b_ = type_convert<T>(count_b);
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
|
||||
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
var_a += var_b + delta * delta * count_a_ * count_b_over_count;
|
||||
count_a = count;
|
||||
}
|
||||
|
||||
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
|
||||
{
|
||||
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
const auto rs_idx =
|
||||
mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
const int original_count = count;
|
||||
|
||||
// loop over thread data
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local_mean = mean_tensor.get_thread_buffer()[i];
|
||||
auto v_local_var = var_tensor.get_thread_buffer()[i];
|
||||
auto v_local_count = original_count;
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
constexpr index_t lid_delta =
|
||||
lid_over_rid_derivative * (1 << (nstage - istage - 1));
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta);
|
||||
const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta);
|
||||
const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta);
|
||||
|
||||
// welford merge
|
||||
Merge(v_local_mean,
|
||||
v_local_var,
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// cross-lane broadcast for replication
|
||||
// only broadcast on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
if constexpr(BroadcastLane)
|
||||
{
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
const index_t r_id = rs_idx[idim_r];
|
||||
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// broadcast sweep backward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// do I hold reduced data?
|
||||
const bool do_i_hold_reduced_data = r_id < (1 << istage);
|
||||
|
||||
constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta);
|
||||
const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta);
|
||||
const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta);
|
||||
|
||||
// decide whether to update local data with remote data
|
||||
v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean;
|
||||
v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var;
|
||||
v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
mean_tensor.get_thread_buffer()(i) = v_local_mean;
|
||||
|
||||
if constexpr(GetActualVariance)
|
||||
var_tensor.get_thread_buffer()(i) = v_local_var / v_local_count;
|
||||
else
|
||||
var_tensor.get_thread_buffer()(i) = v_local_var;
|
||||
|
||||
count = v_local_count;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user