mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Batched gemm and reduction (#156)
* adding batched_gemm_and_reduction
* batched_gemm_reduce works with batch_count=1
* fix a bug in grid_size; batched_gemm_reduce works for batch_count > 1
* adding profiler for batched_gemm_fp16
* fixed a bug in declaration of d1 and d0; both example and profiler work
* clang-format
* cleanup
* batched_gemm_reduce: add test
* minor change
* fixed some typos in function names
[ROCm/composable_kernel commit: 34c661e71c]
This commit is contained in:
@@ -73,10 +73,10 @@ struct HostTensorDescriptor
|
||||
HostTensorDescriptor() = delete;
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor(std::vector<X> lens);
|
||||
HostTensorDescriptor(const std::vector<X>& lens);
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides);
|
||||
HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides);
|
||||
|
||||
void CalculateStrides();
|
||||
|
||||
@@ -285,13 +285,14 @@ struct Tensor
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
|
||||
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens)
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
|
||||
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
|
||||
const std::vector<Y>& strides)
|
||||
: mLens(lens), mStrides(strides)
|
||||
{
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user