mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
introducing ck_tile! (#1216)
* enable gfx940
* switch between intrinsic mfma routines on mi100/200 and mi300
* fix mfma_int8 on MI300
* disable 2 int8 examples on MI300
* Update cmake-ck-dev.sh
* restore gitignore file
* modify Jenkinsfile to the internal repo
* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx
Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)
---
updated-dependencies:
- dependency-name: rocm-docs-core
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* initial enablement of gfx950
* fix clang format
* disable examples 31 and 41 int8 on gfx950
* add code
* fix build wip
* fix xx
* now can build
* naming
* minor fix
* wip fix
* fix macro for exp2; fix warpgemm a/b in transposedC
* unify as tuple_array
* Update the required Python version to 3.9
* Update executable name in test scripts
* re-structure tuple/array to avoid spill
* Merge function templates
* Fix format
* Add constraint to array<> ctor
* Re-use function
* Some minor changes
* remove wrong code in store_raw()
* fix compile issue in transpose
* Rename enum
Rename 'cood_transform_enum' to 'coord_transform_enum'
* let more integral_constant->constant, and formating
* make sure thread_buffer can be tuple/array
* temp fix buffer_store spill
* not using custom data type by default, now we can have ISA-level same code as opt_padding
* fix compile error, fp8 not ready now
* fix fp8 duplicated move/shift/and/or problem
* Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode
* fix scratch in fp8 kernel
* update some readme
* fix merge from upstream
* sync with upstream
* sync upstream again
* sync 22
* remove unused
* fix clang-format
* update README of ck_tile example
* fix several issue
* let python version to be 3.8 as minimal
* remove ck_tile example from default cmake target like all/install/check
* remove mistake
* 1).support receipe in generate.py 2).use simplified mask type 3).change left/right to pass into karg
* fix some bug in group-mode masking and codegen. update README
* F8 quantization for FMHA forward (#1224)
* Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline
* Add element function to fmha api
* Adjust P elementwise function
* Fix bug of elementwise op, our elementwise op is not inout
* Add some elementwise op, prepare to quantization
* Let generate.py can generate different elementwise function
* To prevent compiler issue, remove the elementwise function we have not used.
* Remove f8 pipeline, we should share the same pipeline even in f8
* Remove remove_cvref_t
* Avoid warning
* Fix wrong fp8 QK/KV block gemm setting
* Check fp8 rounding error in check_err()
* Set fp8 rounding error for check_err()
* Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode
* 1. codgen the f8 api and kernel
2. f8 host code
* prevent warning in filter mode
* Remove not-in-use elementwise function kargs
* Remove more not-in-use elementwise function kargs
* Small refinements in C++ source files
* Use conditional_t<> to simplify code
* Support heterogeneous argument for binary function types
* Re-use already-existing scales<> functor template
* Fix wrong value produced by saturating
* Generalize the composes<> template
* Unify saturates<> implementation
* Fix type errors in composes<>
* Extend less_equal<>
* Reuse the existing template less_equal<> in check_err()
* Add equal<float> & equal<double>
* Rename check_err() parameter
* Rename check_err() parameter
* Add FIXME comment for adding new macro in future
* Remove unnecessary cast to void
* Eliminate duplicated code
* Avoid dividing api pool into more than 2 groups
* Use more clear variable names
* Use affirmative condition in if stmt
* Remove blank lines
* Donot perfect forwarding in composes<>
* To fix compile error, revert generate.py back to 4439cc107d
* Fix bug of p element function
* Add compute element op to host softmax
* Remove element function in api interface
* Extract user parameter
* Rename pscale and oscale variable
* rename f8 to fp8
* rename more f8 to fp8
* Add pipeline::operator() without element_functor
* 1. Remove deprecated pipeline enum
2. Refine host code parameter
* Use quantization range as input
* 1. Rename max_dtype to dtype_max.
2. Rename scale to scale_s
3.Add init description
* Refine description
* prevent early return
* unify _squant kernel name in cpp, update README
* Adjust the default range.
* Refine error message and bias range
* Add fp8 benchmark and smoke test
* fix fp8 swizzle_factor=4 case
---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
105
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
Normal file
105
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
Normal file
@@ -0,0 +1,105 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp16
|
||||
using WarpGemmMfmaF16F16F32M32N32K8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
|
||||
2>>;
|
||||
|
||||
// bf16
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
|
||||
2>>;
|
||||
|
||||
// fp8
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
471
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
Normal file
471
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
Normal file
@@ -0,0 +1,471 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_>
|
||||
struct WarpGemmAtrributeMfma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return Impl{}(a_vec, b_vec);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
|
||||
struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_>
|
||||
struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::BVecType;
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}(c_vec, b_vec, a_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
// swap A and B
|
||||
return Impl{}(b_vec, a_vec);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_>
|
||||
struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::BVecType;
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
|
||||
Impl::kABKLane,
|
||||
2,
|
||||
Impl::kABKPerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}(c_vec, b_vec, a_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
// swap A and B
|
||||
return Impl{}(b_vec, a_vec);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
|
||||
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
|
||||
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
#if 0
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
|
||||
Impl::kABKLane,
|
||||
2,
|
||||
Impl::kABKPerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
#else
|
||||
// TODO: more test not only 32x32
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
Impl::kCMLane,
|
||||
SFactor,
|
||||
Impl::kCM1PerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
#endif
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
379
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
Normal file
379
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
Normal file
@@ -0,0 +1,379 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// FP16
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
{
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 8;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
{
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Bf16
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
{
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<bf16_t, 4>;
|
||||
using BVecType = ext_vector_t<bf16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 8;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
{
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<bf16_t, 4>;
|
||||
using BVecType = ext_vector_t<bf16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// FP8
|
||||
template <typename AType_, typename BType_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
{
|
||||
using ADataType = AType_;
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
|
||||
.template get_as<BDataType>()[number<k>{}]);
|
||||
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
|
||||
.template get_as<BDataType>()[number<k>{}]);
|
||||
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>;
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t>;
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t>;
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t>;
|
||||
|
||||
} // namespace ck_tile
|
||||
65
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
Normal file
65
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
Normal file
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
using WarpGemmMfmaDispatcher = typename impl::
|
||||
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
74
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
Normal file
74
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
Normal file
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WarpGemmAttribute_>
|
||||
struct WarpGemmImpl
|
||||
{
|
||||
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
|
||||
|
||||
static constexpr index_t kM = WarpGemmAttribute::kM;
|
||||
static constexpr index_t kN = WarpGemmAttribute::kN;
|
||||
static constexpr index_t kK = WarpGemmAttribute::kK;
|
||||
|
||||
using ADataType = typename WarpGemmAttribute::ADataType;
|
||||
using BDataType = typename WarpGemmAttribute::BDataType;
|
||||
using CDataType = typename WarpGemmAttribute::CDataType;
|
||||
|
||||
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
|
||||
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
|
||||
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
|
||||
|
||||
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
|
||||
using BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>;
|
||||
using CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTensor = static_distributed_tensor<ADataType, AWarpDstr>;
|
||||
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
|
||||
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
|
||||
{
|
||||
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(const AWarpTensor& a, const BWarpTensor& b) const
|
||||
{
|
||||
CWarpTensor c;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user