mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
fix fp8 duplicated move/shift/and/or problem
This commit is contained in:
@@ -54,3 +54,4 @@
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
|
||||
@@ -20,3 +20,4 @@
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
|
||||
|
||||
@@ -4,3 +4,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
@@ -18,3 +18,4 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
@@ -75,8 +75,10 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
|
||||
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
@@ -112,10 +114,15 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -123,16 +130,21 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
@@ -269,8 +281,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
|
||||
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
@@ -306,11 +320,15 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -318,15 +336,20 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
@@ -343,8 +366,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
|
||||
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
@@ -406,27 +431,36 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
|
||||
@@ -14,9 +14,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = array<fp16_t, 4>;
|
||||
using BVecType = array<fp16_t, 4>;
|
||||
using CVecType = array<float, 16>;
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
@@ -36,25 +36,14 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
fp32x16_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,9 +53,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = array<fp16_t, 4>;
|
||||
using BVecType = array<fp16_t, 4>;
|
||||
using CVecType = array<float, 4>;
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
@@ -86,25 +75,14 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
fp32x4_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -115,9 +93,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = array<bf16_t, 4>;
|
||||
using BVecType = array<bf16_t, 4>;
|
||||
using CVecType = array<float, 16>;
|
||||
using AVecType = ext_vector_t<bf16_t, 4>;
|
||||
using BVecType = ext_vector_t<bf16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
@@ -137,25 +115,14 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
fp32x16_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -165,9 +132,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = array<bf16_t, 4>;
|
||||
using BVecType = array<bf16_t, 4>;
|
||||
using CVecType = array<float, 4>;
|
||||
using AVecType = ext_vector_t<bf16_t, 4>;
|
||||
using BVecType = ext_vector_t<bf16_t, 4>;
|
||||
using CVecType = ext_vector_t<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
@@ -187,25 +154,14 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
fp32x4_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -217,9 +173,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = array<ADataType, 8>;
|
||||
using BVecType = array<BDataType, 8>;
|
||||
using CVecType = array<CDataType, 16>;
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
@@ -241,48 +197,27 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 = type_convert<float>(a_vec.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 = type_convert<float>(b_vec.template get_as<BDataType>()[number<k>{}]);
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
|
||||
.template get_as<BDataType>()[number<k>{}]);
|
||||
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x2f32(
|
||||
a_f32, b_f32, c_vec.template get_as<fp32x16_t>()[number<0>{}], 0, 0, 0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -33,9 +33,9 @@ struct WarpGemmImpl
|
||||
|
||||
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
|
||||
{
|
||||
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
@@ -53,9 +53,9 @@ struct WarpGemmImpl
|
||||
{
|
||||
CWarpTensor c;
|
||||
|
||||
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user