fix fp8 duplicated move/shift/and/or problem

This commit is contained in:
carlushuang
2024-03-19 23:29:51 +00:00
parent 886d040a81
commit bb1f6e48eb
8 changed files with 108 additions and 134 deletions

View File

@@ -54,3 +54,4 @@
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

View File

@@ -20,3 +20,4 @@
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"

View File

@@ -4,3 +4,4 @@
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -5,3 +5,4 @@
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -18,3 +18,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -75,8 +75,10 @@ struct WarpGemmAtrributeMfmaIterateK
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
@@ -112,10 +114,15 @@ struct WarpGemmAtrributeMfmaIterateK
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
reinterpret_cast<const buf_a>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
@@ -123,16 +130,21 @@ struct WarpGemmAtrributeMfmaIterateK
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
constexpr auto I0 = number<0>{};
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// c = a * b
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
b_vec.template get_as<typename Impl::BVecType>()[I0]);
auto c_vec = Impl{}(
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
// c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
reinterpret_cast<const buf_a>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
return c_vec;
@@ -269,8 +281,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
@@ -306,11 +320,15 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
});
}
@@ -318,15 +336,20 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
constexpr auto I0 = number<0>{};
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// swap A and B, value and type
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
a_vec.template get_as<typename Impl::BVecType>()[I0]);
auto c_vec = Impl{}(
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
});
return c_vec;
@@ -343,8 +366,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
@@ -406,27 +431,36 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
constexpr auto I0 = number<0>{};
// swap A and B, value and type
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
a_vec.template get_as<typename Impl::BVecType>()[I0]);
auto c_vec = Impl{}(
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
});
return c_vec;

View File

@@ -14,9 +14,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
using BDataType = fp16_t;
using CDataType = float;
using AVecType = array<fp16_t, 4>;
using BVecType = array<fp16_t, 4>;
using CVecType = array<float, 16>;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<float, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
@@ -36,25 +36,14 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
fp32x16_t{0.f},
0,
0,
0));
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
}
};
@@ -64,9 +53,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
using BDataType = fp16_t;
using CDataType = float;
using AVecType = array<fp16_t, 4>;
using BVecType = array<fp16_t, 4>;
using CVecType = array<float, 4>;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
@@ -86,25 +75,14 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec.template get_as<fp32x4_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x4_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
fp32x4_t{0.f},
0,
0,
0));
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
}
};
@@ -115,9 +93,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
using BDataType = bf16_t;
using CDataType = float;
using AVecType = array<bf16_t, 4>;
using BVecType = array<bf16_t, 4>;
using CVecType = array<float, 16>;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
using CVecType = ext_vector_t<float, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
@@ -137,25 +115,14 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
fp32x16_t{0.f},
0,
0,
0));
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
}
};
@@ -165,9 +132,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
using BDataType = bf16_t;
using CDataType = float;
using AVecType = array<bf16_t, 4>;
using BVecType = array<bf16_t, 4>;
using CVecType = array<float, 4>;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
@@ -187,25 +154,14 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec.template get_as<fp32x4_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x4_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
fp32x4_t{0.f},
0,
0,
0));
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
}
};
@@ -217,9 +173,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
using BDataType = BType_;
using CDataType = float;
using AVecType = array<ADataType, 8>;
using BVecType = array<BDataType, 8>;
using CVecType = array<CDataType, 16>;
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
@@ -241,48 +197,27 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#else
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 = type_convert<float>(a_vec.template get_as<ADataType>()[number<k>{}]);
float b_f32 = type_convert<float>(b_vec.template get_as<BDataType>()[number<k>{}]);
float a_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ADataType>()[number<k>{}]);
float b_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<BDataType>()[number<k>{}]);
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x2f32(
a_f32, b_f32, c_vec.template get_as<fp32x16_t>()[number<0>{}], 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
});
#endif
}

View File

@@ -33,9 +33,9 @@ struct WarpGemmImpl
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
{
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -53,9 +53,9 @@ struct WarpGemmImpl
{
CWarpTensor c;
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};