Fused attention (#345)

* initial stub for gemm_gemm_xdl_cshuffle
* set up example code
* compiles
* prevent integer overflow
* harmonize interface between ref_gemm and ref_batched_gemm
* batched_gemm_gemm
* fix example
* host tensor gen: diagonal pattern in lowest two dimensions only
* make c descriptors contain only integral constants
* clean up
* add BlockwiseGemmXdlops_v2 while exploring a unified approach
* implement proper interface
* tidy up example
* fix compilation warnings
* coarsely controlled 2nd gemm padding
* remove rocm-cmake's hard requirement for certain revision
* clang-format
* resolve merge conflict
* fix compilation error on gfx10
* add acc0 elementwise op to interface
* attention host validation
* add blockwise softmax v1
* iteratively update softmax+gemm
* transpose both gemm0 and gemm1 xdl output so as to avoid broadcasting softmax max/sum
* add init method for easier debugging
* do away with manual thread cluster calculation
* generalize blockwise softmax interface
* row-wise softmax sum & max
* format
* rename to DeviceBatchedGemmSoftmaxGemm
* add gemm_softmax_gemm instances and tests
* comment

Co-authored-by: ltqin <letao.qin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
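
At a high level the new DeviceBatchedGemmSoftmaxGemm computes C = softmax(A * B0) * B1 tile-by-tile without spilling the intermediate product to global memory. A minimal sketch of the per-block flow (names are illustrative, not taken from the diff):

// acc0 = blockwise_gemm0(a_tile, b0_tile);   // S tile kept in vgpr
// softmax.Run(acc0, lds_workspace);          // acc0 <- exp(S - rowmax(S)), row sums recorded
// scale acc0 by 1 / rowsum;                  // normalize in place
// acc1 += blockwise_gemm1(acc0, b1_tile);    // 2nd gemm reads its A operand from vgpr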
@@ -25,6 +25,22 @@ constexpr LoopScheduler make_default_loop_scheduler()
#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
}

template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
    constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
    constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

    return transform_tensor_descriptor(
        TileDesc_K0_MN_K1{},
        make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                   make_unmerge_transform(
                       make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
        make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
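
// Illustration (assumed tile sizes, not part of the diff): for a source tile with
// (K0, MN, K1) = (4, 128, 8) and MNXdlPerWave = 2, MNWaves = 2, MNPerXdl = 32, the
// transform above merges (K0, K1) into K = 32 (fed by source dims {0, 2}) and unmerges
// MN = 128 into (MN0, MN1, MN2) = (2, 2, 32), yielding a (MN0, MN1, MN2, K) =
// (2, 2, 32, 32) descriptor, exactly as encoded by the Sequence<> maps.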

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
@@ -585,4 +601,361 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
}
};

// Blockwise gemm supporting
// 1. regular XDL output M2_M3_M4_N2 and transposed XDL output M2_N2_N3_N4
// 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and
//    LDS source buffer
// 3. configurable k index starting position and step size after each FMA/XDL instruction
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool TransposeC = false,
          index_t AMmaKStride =
              KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops,
          index_t BMmaKStride =
              KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_v2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = get_warp_size();

    static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};

    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
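
    // Sketch (assuming MWaves = 2, NWaves = 2, WaveSize = 64): the merge adaptor inverts
    // thread_id = (waveId_m * NWaves + waveId_n) * WaveSize + lane, so thread_id = 197 =
    // (1 * 2 + 1) * 64 + 5 yields wave_idx = (waveId_m, waveId_n, lane) = (1, 1, 5).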

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];

        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_n = wave_idx[I1];

        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
        const auto blk_idx =
            TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
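
    // Sketch (assuming MRepeat = 2, MWaves = 2, MPerXDL = 32): the M-side unmerge adaptor
    // maps (m0, waveId_m, blk_m) to the block-level row
    // c_thread_m = (m0 * MWaves + waveId_m) * MPerXDL + blk_m, e.g. (1, 0, 7) -> 71;
    // the N side is computed analogously.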

    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
                                               Tuple4 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                      "wrong!");
    }

    __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
        : a_thread_copy_(other.a_thread_copy_), b_thread_copy_(other.b_thread_copy_)
    {
    }

    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
    }

    // XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    // XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }
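
    // Sketch (assuming M = 256, N = 256, MWaves = NWaves = 2, MPerXDL = NPerXDL = 32):
    // the unmerges above split M into (M0, M1, M2) = (4, 2, 32) and N likewise, and the
    // Sequence<0, 2, 4> / Sequence<1, 3, 5> placement interleaves them into
    // (M0, N0, M1, N1, M2, N2) before the xdlops-specific M2/N2 sub-split is applied.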

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }

    static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
    static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

        static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k = 0, 1, 2, ... instead of k = 0, KPack * 1, ...
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                // read A
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   a_thread_buf);

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    // read B
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf);

                    vector_type<FloatAB, KPack> a_thread_vec;
                    vector_type<FloatAB, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }
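
    // Note: the loop nest above issues one xdlops_gemm.Run per (k, m0, n0) triple; A is
    // re-read once per (k, m0) and B once per (k, m0, n0), and each call accumulates a
    // KPack-long K-slab into the (m0, n0) slice of c_thread_buf.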

    protected:
    // A[M0, M1, M2, KPerThread]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // B[N0, N1, N2, KPerThread]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};
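
// A minimal usage sketch (assumed parameters, not part of this diff): with fp16 inputs,
// fp32 accumulation, a 256-thread block and a 128x128x32 tile, MWaves = NWaves = 2, so
// BlockSize = 2 * 2 * 64; the decoupled ATileDesc/AMmaTileDesc pair lets gemm1 of the
// fused kernel read its A operand straight from vgpr while B stays in LDS:
//
//   using BlockGemm = BlockwiseGemmXdlops_v2<256, half_t, float,
//                                            ATileDesc, BTileDesc,       // source tiles
//                                            AMmaTileDesc, BMmaTileDesc, // mma tiles
//                                            128, 128, 32,               // M/N/KPerBlock
//                                            32, 32,                     // M/NPerXDL
//                                            2, 2,                       // M/NRepeat
//                                            8>;                         // KPack
//   BlockGemm blockwise_gemm;
//   blockwise_gemm.Run(a_block_buf, b_block_buf, blockwise_gemm.GetCThreadBuffer());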

} // namespace ck
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp (new file, +96 lines)
@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"

namespace ck {

template <index_t BlockSize,
          typename AccDataType,
          typename ThreadMap_M_K, // thread_id to m_k
          typename ThreadClusterDesc_M_K,
          typename ThreadSliceDesc_M_K>
struct BlockwiseSoftmax
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
    static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

    using ThreadSliceDesc_M = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));

    using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
                                                    ThreadSliceDesc_M_K,
                                                    ThreadSliceDesc_M,
                                                    reduce::Max,
                                                    false>;

    using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());

    using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
                                                                BlockSize,
                                                                ThreadClusterLengths_M_K,
                                                                ThreadMap_M_K,
                                                                reduce::Max,
                                                                false>;

    using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2<AccDataType,
                                                                BlockSize,
                                                                ThreadClusterLengths_M_K,
                                                                ThreadMap_M_K,
                                                                reduce::Add,
                                                                false>;

    using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
                                                    ThreadSliceDesc_M_K,
                                                    ThreadSliceDesc_M,
                                                    reduce::Add,
                                                    false>;

    using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;

    template <typename CThreadBuffer, typename WorkspaceBuffer>
    __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
    {
        // find max value
        static_for<0, MRepeat, 1>{}([&](auto I) {
            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
        });
        ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
        static_for<0, MRepeat, 1>{}([&](auto I) {
            BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
            block_sync_lds();
        });

        // calculate exp for elements, P = exp(s - max)
        static_for<0, MRepeat, 1>{}([&](auto iM) {
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM));
            });
        });

        // sum data
        static_for<0, MRepeat, 1>{}([&](auto I) {
            sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });
        ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
        static_for<0, MRepeat, 1>{}([&](auto I) {
            BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
            block_sync_lds();
        });
    }

    BufferType max_value_buf;
    BufferType sum_value_buf;
};
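
// A minimal usage sketch (assumed types, not part of this diff): after gemm0 each thread
// holds an MRepeat x KRepeat accumulator slice; Run() reduces row-wise max and sum across
// the block and leaves exp(s - max) in place, so the caller rescales by 1 / sum before
// feeding the tile to gemm1:
//
//   using Softmax = BlockwiseSoftmax<256, float, ThreadMap_M_K,
//                                    ThreadClusterDesc_M_K, ThreadSliceDesc_M_K>;
//   Softmax blockwise_softmax;
//   blockwise_softmax.Run(acc0_thread_buf, lds_work_buf);
//   // acc0_thread_buf now holds exp(s - max); divide by blockwise_softmax.sum_value_buf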

} // namespace ck
@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction
    };
};

// clang-format off
// Assume:
//  1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
//  2) work_buffer has AccDataType elements, and space size is no less than BlockSize
//  3) in_out_value is the input data in vgpr from each thread
//  4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterDesc,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction_v2
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be the same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static_assert(BufferLength_K > 1, "Parallel reduction needs to work on at least two elements");

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc = ThreadClusterDesc{};

    template <typename BufferType>
    __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent with AccDataType!");

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;

        __syncthreads();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1 = work_buffer[offset1];
                AccDataType opData2 = work_buffer[offset2];
                Accumulation::Calculate(opData1, opData2);
                work_buffer(offset1) = opData1;
            }

            __syncthreads();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_buffer[offset];
    }
};
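
// Sketch (assuming BufferLength_K = 8, hence cluster_len_shift = 3): the rounds use
// indOffset = 4, 2, 1; in each round lane k < indOffset folds work_buffer[m][k + indOffset]
// into work_buffer[m][k], and after the final round every thread of row m reads the fully
// reduced value back from work_buffer[m][0].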

// clang-format off
// Assume:
//  1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data