Split the static library into several files. (#1044)

* spolit the static library into several

* update lib paths and fix client example

* do not use device_mha_operarions for client examples

* use appropriate libs to link to client examples

* remove the gpu/transpose path from the list

* try fixing clinet examples 3,4,9

* add necessary libs for client examples

* fix the layernorm client example

* fix the client examples 23 and 24

* fix typo

* add interface library and refresh clang format

[ROCm/composable_kernel commit: 7965d66a81]
This commit is contained in:
Illia Silin
2023-11-28 11:17:37 -08:00
committed by GitHub
parent 4e27eae99d
commit 2e8781c151
35 changed files with 225 additions and 122 deletions

View File

@@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example()
{0, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
{1},
1e-4,
x_dev.GetDeviceBuffer(),

View File

@@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
{0, 0, 0, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
{1, 2, 4}, // reduction dimension: [H, W, C]
1e-6,
x_dev.GetDeviceBuffer(),

View File

@@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example()
{0, W * C, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
{1, 2, 3},
1e-4,
x_dev.GetDeviceBuffer(),