Batched GEMM for fp16 (#79)

* prepare host for batched_gemm

* init commit of batched kernels

* fixed

* refine transform with freeze

* m/n padding

* fixed a bug; clean

* add small tiles

* clean

* clean code

* clean code

* add nt, tn, tt layouts

* add missing file

* use StaticBufferTupleOfVector instead

* add reference_batched_gemm

* fixed a macro
This commit is contained in:
zjing14
2022-02-11 09:36:52 -06:00
committed by GitHub
parent 6f928a0876
commit b53e9d08ed
16 changed files with 2098 additions and 13 deletions

View File

@@ -6,6 +6,7 @@
#include <half.hpp>
// Entry points of the individual profilers, defined outside this file.
// main() forwards its argc/argv to one of them unchanged and returns the
// callee's result; the branches visible in this file select the callee by
// comparing argv[1] against the operation name (e.g. "batched_gemm").
int profile_gemm(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_conv_fwd(int, char*[]);
@@ -19,14 +20,18 @@ int main(int argc, char* argv[])
{
return profile_gemm(argc, argv);
}
if(strcmp(argv[1], "gemm_bias_relu") == 0)
else if(strcmp(argv[1], "gemm_bias_relu") == 0)
{
return profile_gemm_bias_relu(argc, argv);
}
if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
{
return profile_gemm_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm") == 0)
{
return profile_batched_gemm(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd") == 0)
{
return profile_conv_fwd(argc, argv);