mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add VectorType support into StaticBuffer (#27)
* init StaticBufferV2 * clean * adopt old output stage for staticBufferV2 * clean * remove hack * clean * clean * clean code * move c_buffer alloc into blockwise gemm * add adaptors for m/n_thread_data_on_grid * adjust blockwise_gemm_xdlops * reorder ops in GEMM hot loop Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -10,6 +10,7 @@ namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerXDL,
|
||||
@@ -29,14 +30,18 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
|
||||
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t KPerBlock = K0;
|
||||
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
@@ -162,7 +167,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
AK0MK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_tuple(make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
@@ -174,7 +179,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
BK0NK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_tuple(make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
@@ -195,48 +200,43 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
vector_type<FloatAB, K1> a_thread_vec;
|
||||
|
||||
vector_type<FloatAB, K1> b_thread_vec;
|
||||
|
||||
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
|
||||
make_tuple(k0, I0, I0, I0, I0),
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B
|
||||
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
|
||||
make_tuple(k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
using mfma_input_type = typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
|
||||
static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
|
||||
vector_type<FloatAB, K1> a_thread_vec;
|
||||
vector_type<FloatAB, K1> b_thread_vec;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, K1, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
|
||||
});
|
||||
|
||||
static_for<0, K1, 1>{}([&](auto i) {
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
|
||||
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
using mfma_input_type =
|
||||
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<c_offset>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf);
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0));
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVector(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -244,35 +244,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
private:
|
||||
// A[K, M]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
|
||||
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));
|
||||
static constexpr auto c_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_m2_k1_block_desc),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, MRepeat, 1, 1, K1>,
|
||||
Sequence<K0, 1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
K1,
|
||||
1>;
|
||||
K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n0_n1_n2_k1_block_desc),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, NRepeat, 1, 1, K1>,
|
||||
Sequence<K0, 1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
K1,
|
||||
1>;
|
||||
K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
|
||||
@@ -142,6 +142,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
@@ -220,6 +221,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
@@ -363,9 +365,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
const auto blockwise_gemm =
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
@@ -374,18 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
NRepeat,
|
||||
K1>{};
|
||||
|
||||
constexpr auto c_mr_nr_blk_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
|
||||
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
|
||||
constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
vector_type<FloatAcc, CBlkSize>,
|
||||
c_mr_nr_blk_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
@@ -460,9 +452,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
|
||||
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
|
||||
|
||||
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
|
||||
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
|
||||
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
|
||||
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
|
||||
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
|
||||
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
|
||||
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
|
||||
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -477,224 +478,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
|
||||
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_grid_idx =
|
||||
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_grid));
|
||||
|
||||
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_grid_idx =
|
||||
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_grid));
|
||||
|
||||
auto c_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
|
||||
Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
|
||||
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{
|
||||
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
m_thread_data_on_grid / (M3 * M4),
|
||||
m_thread_data_on_grid % (M3 * M4) / M4,
|
||||
m_thread_data_on_grid % M4,
|
||||
n_thread_data_on_grid)};
|
||||
make_multi_index(m_thread_data_on_grid_idx[I0],
|
||||
n_thread_data_on_grid_idx[I0],
|
||||
m_thread_data_on_grid_idx[I1],
|
||||
n_thread_data_on_grid_idx[I1],
|
||||
m_thread_data_on_grid_idx[I2],
|
||||
m_thread_data_on_grid_idx[I3],
|
||||
m_thread_data_on_grid_idx[I4],
|
||||
n_thread_data_on_grid_idx[I2])};
|
||||
|
||||
auto init_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
|
||||
return c_thread_idx_;
|
||||
};
|
||||
|
||||
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
mrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
nrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
mrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
nrepeat_step_minus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
||||
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
|
||||
(MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
|
||||
(MRepeat == 1 && NRepeat == 1),
|
||||
"wrong");
|
||||
|
||||
if constexpr(MRepeat == 4 && NRepeat == 4)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
nrepeat_plus_copy(make_tuple(I0, I3));
|
||||
mrepeat_plus_copy(make_tuple(I1, I3));
|
||||
nrepeat_minus_copy(make_tuple(I1, I2));
|
||||
nrepeat_minus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
nrepeat_plus_copy(make_tuple(I2, I1));
|
||||
nrepeat_plus_copy(make_tuple(I2, I2));
|
||||
nrepeat_plus_copy(make_tuple(I2, I3));
|
||||
mrepeat_plus_copy(make_tuple(I3, I3));
|
||||
nrepeat_minus_copy(make_tuple(I3, I2));
|
||||
nrepeat_minus_copy(make_tuple(I3, I1));
|
||||
nrepeat_minus_copy(make_tuple(I3, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
mrepeat_plus_copy(make_tuple(I3, I0));
|
||||
nrepeat_plus_copy(make_tuple(I3, I1));
|
||||
mrepeat_minus_copy(make_tuple(I2, I1));
|
||||
mrepeat_minus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
mrepeat_plus_copy(make_tuple(I1, I2));
|
||||
mrepeat_plus_copy(make_tuple(I2, I2));
|
||||
mrepeat_plus_copy(make_tuple(I3, I2));
|
||||
nrepeat_plus_copy(make_tuple(I3, I3));
|
||||
mrepeat_minus_copy(make_tuple(I2, I3));
|
||||
mrepeat_minus_copy(make_tuple(I1, I3));
|
||||
mrepeat_minus_copy(make_tuple(I0, I3));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 4 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
mrepeat_plus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
nrepeat_plus_copy(make_tuple(I2, I1));
|
||||
mrepeat_plus_copy(make_tuple(I3, I1));
|
||||
nrepeat_minus_copy(make_tuple(I3, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
mrepeat_plus_copy(make_tuple(I3, I0));
|
||||
nrepeat_plus_copy(make_tuple(I3, I1));
|
||||
mrepeat_minus_copy(make_tuple(I2, I1));
|
||||
mrepeat_minus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 4)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
nrepeat_plus_copy(make_tuple(I0, I3));
|
||||
mrepeat_plus_copy(make_tuple(I1, I3));
|
||||
nrepeat_minus_copy(make_tuple(I1, I2));
|
||||
nrepeat_minus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
nrepeat_plus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
mrepeat_plus_copy(make_tuple(I1, I2));
|
||||
nrepeat_plus_copy(make_tuple(I1, I3));
|
||||
mrepeat_minus_copy(make_tuple(I0, I3));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
mrepeat_plus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
nrepeat_plus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 1)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else if constexpr(MRepeat == 1 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
else if constexpr(MRepeat == 1 && NRepeat == 1)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
}
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
}
|
||||
}
|
||||
}; // namespace ck
|
||||
|
||||
@@ -44,15 +44,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
|
||||
static constexpr index_t k_per_blk = 1;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -71,15 +66,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
|
||||
static constexpr index_t k_per_blk = 1;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -98,15 +88,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
|
||||
static constexpr index_t k_per_blk = 1;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -125,15 +110,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
|
||||
static constexpr index_t k_per_blk = 1;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -153,15 +133,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
|
||||
static constexpr index_t k_per_blk = 1;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -180,15 +155,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4f16>
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -207,15 +177,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -234,15 +199,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -261,15 +221,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4f16>
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -288,15 +243,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -732,7 +682,7 @@ struct XdlopsGemm
|
||||
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
|
||||
}
|
||||
|
||||
template <index_t c_offset, class FloatA, class FloatB, class FloatC>
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
|
||||
@@ -740,8 +690,7 @@ struct XdlopsGemm
|
||||
"base base_type must be float, half, ushort!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops, c_offset>(
|
||||
p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -819,8 +768,9 @@ struct XdlopsGemm
|
||||
|
||||
static constexpr auto mfma_instr = mfma.selected_mfma;
|
||||
|
||||
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
|
||||
static constexpr auto KPerThread = mfma.GetKPerThread();
|
||||
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
|
||||
static constexpr auto K1PerXdlops = mfma.GetKPerThread();
|
||||
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
|
||||
{
|
||||
|
||||
@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x1f32<64, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x1f32<32, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x2f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x2f32<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x4f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x4f32<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x1f32<16, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
2,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x1f32<4, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x1f32<8, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x4f16<64, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x4f16<32, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x8f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x8f16<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x16f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x16f16<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x4f16<16, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
2,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x4f16<4, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x4f16<8, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c);
|
||||
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
|
||||
@@ -55,6 +55,98 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
index_t N,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct StaticBufferV2 : public StaticallyIndexedArray<T, N>
|
||||
{
|
||||
using type = T;
|
||||
using base = StaticallyIndexedArray<T, N>;
|
||||
|
||||
using VecBaseType = typename T::d1_t;
|
||||
|
||||
__host__ __device__ static constexpr index_t GetVectorSize()
|
||||
{
|
||||
return sizeof(typename T::type) / sizeof(VecBaseType);
|
||||
}
|
||||
|
||||
static constexpr index_t vector_size = GetVectorSize();
|
||||
|
||||
VecBaseType invalid_element_value_ = VecBaseType{0};
|
||||
|
||||
T invalid_vec_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr StaticBufferV2() : base{} {}
|
||||
|
||||
__host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value)
|
||||
: base{},
|
||||
invalid_vec_value_{invalid_element_value},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetElement(Number<I> i, bool)
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetElement(Number<I> i, bool is_valid_element) const
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: VecBaseType{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: invalid_element_value_;
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator[](Number<I> i) const
|
||||
{
|
||||
return GetElement(i, true);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
||||
{
|
||||
return GetElement(i, true);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user