Batched gemm and reduction (#156)

* adding batched_gemm_and_reduction

* batched_gemm_reduce works with bactch_count=1

* fix a bug in grid_size; batched_gemm_reduce works for batch_count > 1

* adding profiler for batched_gemm_fp16

* fixed a bug in declaration of d1 and d0; both example and profiler work

* clang-format

* cleanup

* batched_gemm_reduce: add test

* minor change

* fixed some typo in function names

[ROCm/composable_kernel commit: 34c661e71c]
This commit is contained in:
Jianfeng Yan
2022-03-30 11:21:18 -05:00
committed by GitHub
parent 6d537a8c3e
commit cb97ce68d8
27 changed files with 2145 additions and 62 deletions

View File

@@ -73,10 +73,10 @@ struct HostTensorDescriptor
HostTensorDescriptor() = delete;
template <typename X>
HostTensorDescriptor(std::vector<X> lens);
HostTensorDescriptor(const std::vector<X>& lens);
template <typename X, typename Y>
HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides);
HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides);
void CalculateStrides();
@@ -285,13 +285,14 @@ struct Tensor
};
template <typename X>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens)
{
this->CalculateStrides();
}
template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
const std::vector<Y>& strides)
: mLens(lens), mStrides(strides)
{
}