mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
adding implicit gemm
This commit is contained in:
@@ -343,7 +343,7 @@ int main()
|
||||
constexpr unsigned K = 1;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 1
|
||||
#elif 0
|
||||
constexpr unsigned N = 1;
|
||||
constexpr unsigned C = 1;
|
||||
constexpr unsigned HI = 34;
|
||||
@@ -396,19 +396,14 @@ int main()
|
||||
#if 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 0
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence<S, R, C, K>{});
|
||||
ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: ");
|
||||
Tensor<float> wei_srck(make_TensorDescriptor(wei_srck_desc));
|
||||
|
||||
auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) {
|
||||
@@ -418,12 +413,7 @@ int main()
|
||||
make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)(num_thread);
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
out_nkhw_device.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#endif
|
||||
|
||||
for(int i = 0; i < 1; ++i)
|
||||
for(int i = 0; i < 40; ++i)
|
||||
{
|
||||
#if 0
|
||||
device_direct_convolution_1(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
@@ -450,7 +440,7 @@ int main()
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
//#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
|
||||
#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
|
||||
#include "gridwise_implicit_gemm_convolution_nchw_srck.cuh"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
@@ -26,14 +26,13 @@ void device_implicit_gemm_convolution(
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
constexpr unsigned NPerBlock = 1;
|
||||
constexpr unsigned KPerBlock = 1;
|
||||
constexpr unsigned CPerBlock = 1;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 1;
|
||||
constexpr unsigned KPerThread = 1;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
@@ -47,13 +46,25 @@ void device_implicit_gemm_convolution(
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 256;
|
||||
#endif
|
||||
|
||||
constexpr unsigned GridSize =
|
||||
|
||||
Reference in New Issue
Block a user