Gemm reduce examples int4/int8/fp32/bf16 (#368)

* GEMM + Reduce max fp16+fp32

* GEmm + Max bf16 + int8

* Refactor common definitions.

* Refactor common func of mean meansquare example.

* More examples for mean meansquare.

* Update int8 examples and skip them cause of random errors.

* Int4 examples.

* Fix examples for max int4/8

* Tensor conversion for int4 input data for mean meansquare example.

* Remove int4 mean_meansquare example

* Fix int8 mean_meansquare example.

-All ReductionAccData and R<N>DataType have to be F32. The INT32 data
type is giving wrong results.

* Guard int4 with ifdef

* Change int8 example to add_addsquare due to div rounding err.

* Clang format

* Change the return type of common function.

* Get back int8 example with division.

* Remove int8 mean meansquare.

* Use proper cast for BF16 data type.

* Use ck::literals.

* Use proper data type for host tensors & reference.

- Use ReduceAccDataType for reference gemm output data type.
- Cast host reference output tensor to EDataType
- Fix ifdefs for int4.

Co-authored-by: Adam Osewski <aosewski@amd.com>

[ROCm/composable_kernel commit: d00e6115b9]
This commit is contained in:
Adam Osewski
2022-08-30 18:38:26 +02:00
committed by GitHub
parent 6446894289
commit 0cc91c73d8
13 changed files with 2141 additions and 359 deletions

View File

@@ -21,9 +21,9 @@ namespace reduce {
// vector space
// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this
// operator can use the InMemoryDataOperation to finalize, or else it return false 3) operator() --
// the first argument of the operator must be both an input & output, and the corresponding variable
// usually stores
// operator can use the InMemoryDataOperation to finalize, or else it return false
// 3) operator() -- the first argument of the operator must be both an input & output, and the
// corresponding variable usually stores
// the accumulated result of many operator() calls; the second argument is only an
// input. For indexable binary
// operator, the second version of operator() has third argument (which is an