mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Merging the gfx12 code into public repo. (#1362)
This commit is contained in:
@@ -13,6 +13,504 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#ifdef __gfx12__
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename ABlockDesc,
|
||||
typename BBlockDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWMMA,
|
||||
index_t NPerWMMA,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool AEnableLds = true,
|
||||
bool BEnableLds = true,
|
||||
bool TransposeC = false>
|
||||
/* Option: Read from LDS; one big buffer holds the data required by all threads
|
||||
* Source
|
||||
* A: K0PerBlock x MPerBlock x K1
|
||||
* B: K0PerBlock x NPerBlock x K1
|
||||
* Destination
|
||||
* C, non-transpose
|
||||
* thread level: MRepeat x NRepeat x MAccVgprs
|
||||
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
||||
* KPACK == WMMA_K = 16
|
||||
*
|
||||
* Option: Read from VMEM; a small buffer holds each thread's own required data (skip LDS)
|
||||
* Source:
|
||||
* A(if skip LDS): MRepeat x KPack
|
||||
* B(if skip LDS): NRepeat x KPack
|
||||
* Destination
|
||||
* C, non-transpose
|
||||
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
||||
*/
|
||||
struct BlockwiseGemmWMMA
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto WmmaK = Number<16>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
|
||||
static constexpr index_t WaveSize = 32;
|
||||
|
||||
// When LDS is used, each row (16 consecutive lanes) reads the whole data from the source buffer
|
||||
// When LDS is not used, each row reads half of the data from the source buffer and exchanges it via
|
||||
// permutation
|
||||
static constexpr index_t A_KRow = 2;
|
||||
static constexpr index_t B_KRow = 2;
|
||||
|
||||
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
|
||||
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
|
||||
|
||||
static constexpr auto wmma_gemm =
|
||||
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
MRepeat * NRepeat,
|
||||
wmma_gemm.GetRegSizePerWmma(),
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
// Mutable access to this object's accumulator storage (c_thread_buf_).
__host__ __device__ constexpr auto& GetCThreadBuffer()
{
    return c_thread_buf_;
}
|
||||
|
||||
// Decompose the flat thread id of this block into a three-component index
// whose lengths are (MWaves, NWaves, WaveSize). Callers read component I0 as
// the wave position along M and component I1 as the wave position along N.
__device__ static auto GetWaveIdx()
{
    // Single merge transform: three lower dimensions <-> one flat upper dimension.
    constexpr auto flat_tid_to_wave_coord = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));

    const index_t tid = ThisThreadBlock::GetThreadId();

    // Going "down" through the merge splits the flat id into its three parts.
    return flat_tid_to_wave_coord.CalculateBottomIndex(make_multi_index(tid));
}
|
||||
|
||||
// Default, Block buffer in LDS, thread level offset enabled
|
||||
// Six-component origin used to construct the A thread-copy object.
// Without LDS (AEnableLds == false) the origin is all zeros; with LDS it
// carries this thread's M-wave id, its sub-group id and the A index reported
// by wmma_gemm.
__device__ static auto CalculateAThreadOriginDataIndex()
{
    if constexpr(!AEnableLds)
    {
        return make_tuple(0, 0, 0, 0, 0, 0);
    }
    else
    {
        const auto wave_coord = GetWaveIdx();
        const auto m_wave     = wave_coord[I0];
        const auto a_lane_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

        //                |KRepeat |MRepeat |MWave  |KRow                      |MLane      |KPack
        return make_tuple(0, 0, m_wave, wmma_gemm.GetSubGroupId(), a_lane_idx, 0);
    }
}
|
||||
|
||||
// Six-component origin used to construct the B thread-copy object.
// Without LDS (BEnableLds == false) the origin is all zeros; with LDS it
// carries this thread's N-wave id, its sub-group id and the B index reported
// by wmma_gemm.
__device__ static auto CalculateBThreadOriginDataIndex()
{
    if constexpr(!BEnableLds)
    {
        return make_tuple(0, 0, 0, 0, 0, 0);
    }
    else
    {
        const auto wave_coord = GetWaveIdx();
        const auto n_wave     = wave_coord[I1];
        const auto b_lane_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

        //                |KRepeat |NRepeat |NWave  |KRow                      |NLane      |KPack
        return make_tuple(0, 0, n_wave, wmma_gemm.GetSubGroupId(), b_lane_idx, 0);
    }
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
|
||||
|
||||
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
|
||||
|
||||
return make_tuple(
|
||||
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
|
||||
}
|
||||
|
||||
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
|
||||
// Construct the blockwise GEMM and seed the A/B thread-copy objects with their
// origins. By default the origins come from CalculateAThreadOriginDataIndex()
// and CalculateBThreadOriginDataIndex(). The body only performs compile-time
// validation of the template configuration.
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
                                      Tuple6 b_origin = CalculateBThreadOriginDataIndex())
    : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
    // Both block descriptors must be fully static.
    static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                  "wrong! Desc should be known at compile-time");

    // The block must consist of exactly MWaves x NWaves waves of WaveSize threads.
    static_assert(MWaves * NWaves * WaveSize == ThisThreadBlock::GetNumOfThread(),
                  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

    // The block tile must be an integer number of (repeat x WMMA tile) groups in M and N.
    static_assert((MPerBlock % (MPerWMMA * MRepeat) == 0) &&
                      (NPerBlock % (NPerWMMA * NRepeat) == 0),
                  "wrong!");
}
|
||||
|
||||
// transposed WMMA output C' = B' * A'
|
||||
// Packed 7-D thread-level descriptor for the transposed C layout.
// Lengths are (MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs), where NAccVgprs is
// element I2 of the thread-block lengths reported by wmma_gemm.
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
    constexpr auto tblk_lengths =
        wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

    constexpr auto acc_vgprs = tblk_lengths[I2];

    // Dimension order, following the function name:
    //   |MRepeat |MWave |MThreadPerSubGroup |NRepeat |NWave |NSubGroup |NAccVgprs
    return make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, acc_vgprs));
}
|
||||
|
||||
// Thread level, register descriptor. Vector-write
|
||||
// Strided 7-D thread-level descriptor for the non-transposed C layout.
// Lengths are (MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs). MAccVgprs and the
// innermost stride (AccStride) are elements I2 and I3 of the thread-block
// lengths reported by wmma_gemm.
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
    constexpr auto tblk_lengths =
        wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

    constexpr auto MAccVgprs = tblk_lengths[I2];
    constexpr auto AccStride = tblk_lengths[I3];

    // Stride shared by the three N-side dimensions, and by the three M-side dimensions.
    constexpr auto n_side_stride = MAccVgprs * AccStride;
    constexpr auto m_side_stride = Number<NRepeat>{} * MAccVgprs * AccStride;

    return make_naive_tensor_descriptor(
        //   |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs
        make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
        make_tuple(m_side_stride,
                   m_side_stride,
                   m_side_stride,
                   n_side_stride,
                   n_side_stride,
                   n_side_stride,
                   AccStride));
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
|
||||
transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return wmma_gemm
|
||||
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
|
||||
}
|
||||
|
||||
// transposed WMMA output C' = B' * A'
|
||||
// Block-level C descriptor for the transposed layout. Starts from a packed
// 6-D descriptor with lengths
// (MRepeat, MWaves, MPerWMMA, NRepeat, NWaves, NPerWMMA) and lets wmma_gemm
// expand it into the layout named by this function.
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
    constexpr auto block_tile_lengths = make_tuple(Number<MRepeat>{},
                                                   Number<MWaves>{},
                                                   Number<MPerWMMA>{},
                                                   Number<NRepeat>{},
                                                   Number<NWaves>{},
                                                   Number<NPerWMMA>{});

    constexpr auto block_tile_desc = make_naive_tensor_descriptor_packed(block_tile_lengths);

    return wmma_gemm
        .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
            block_tile_desc);
}
|
||||
|
||||
// Provide dimension size
|
||||
// Block-level C descriptor for the non-transposed layout. Starts from a packed
// 6-D descriptor with lengths
// (MRepeat, MWaves, MPerWMMA, NRepeat, NWaves, NPerWMMA) and lets wmma_gemm
// expand it into the layout named by this function.
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
    constexpr auto block_tile_lengths = make_tuple(Number<MRepeat>{},
                                                   Number<MWaves>{},
                                                   Number<MPerWMMA>{},
                                                   Number<NRepeat>{},
                                                   Number<NWaves>{},
                                                   Number<NPerWMMA>{});

    constexpr auto block_tile_desc = make_naive_tensor_descriptor_packed(block_tile_lengths);

    return wmma_gemm
        .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
            block_tile_desc);
}
|
||||
|
||||
// Describe how data allocated in thread copy src buffer
|
||||
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
|
||||
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
|
||||
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_assert(KPack % (A_K1 * A_KRow) == 0, "");
|
||||
static_assert(KPack % (B_K1 * B_KRow) == 0, "");
|
||||
|
||||
// basic intrinsic to determine loopover direction
|
||||
if constexpr(MRepeat < NRepeat)
|
||||
{
|
||||
static_for<0, KPerBlock / KPack, 1>{}(
|
||||
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
|
||||
});
|
||||
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<FloatA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<FloatB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
|
||||
// k=0,kpack*1, ..
|
||||
// read B
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
// read A
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
|
||||
});
|
||||
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<FloatA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<FloatB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
|
||||
make_tuple(Number<A_K1>{},
|
||||
Number<KPack / A_KRow>{},
|
||||
Number<A_K1>{},
|
||||
Number<A_K1>{},
|
||||
Number<A_K1>{},
|
||||
Number<1>{}));
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
|
||||
make_tuple(Number<B_K1>{},
|
||||
Number<KPack / B_KRow>{},
|
||||
Number<B_K1>{},
|
||||
Number<B_K1>{},
|
||||
Number<B_K1>{},
|
||||
Number<1>{}));
|
||||
|
||||
// C[M, N, NumRegWMMA]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
|
||||
|
||||
template <bool EnableLds>
|
||||
struct AThreadCopySelector;
|
||||
|
||||
template <>
|
||||
struct AThreadCopySelector<true>
|
||||
{
|
||||
using type =
|
||||
ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AThreadCopySelector<false>
|
||||
{
|
||||
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
|
||||
FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
false>;
|
||||
};
|
||||
|
||||
template <bool EnableLds>
|
||||
struct BThreadCopySelector;
|
||||
|
||||
template <>
|
||||
struct BThreadCopySelector<true>
|
||||
{
|
||||
using type =
|
||||
ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BThreadCopySelector<false>
|
||||
{
|
||||
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
|
||||
FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
B_K1,
|
||||
false>;
|
||||
};
|
||||
|
||||
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
|
||||
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
|
||||
};
|
||||
#else
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA
|
||||
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
|
||||
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
// sync point.
|
||||
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
asm volatile("\
|
||||
s_barrier_signal -1 \n \
|
||||
s_barrier_wait -1 \
|
||||
" ::);
|
||||
#else
|
||||
asm volatile("s_barrier" ::);
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
|
||||
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
|
||||
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
|
||||
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
||||
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.a_kz_stride_ == 1 &&
|
||||
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
||||
if(!(arg.a_kz_stride_ == 1))
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-k check failure\n");
|
||||
return false;
|
||||
index_t LastK =
|
||||
AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
|
||||
if(LastK % ABlockTransferSrcScalarPerVector == 0)
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-k check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,8 +70,9 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
bool pass = true;
|
||||
pass = pass && arg.K_ % K1 == 0;
|
||||
|
||||
@@ -56,7 +56,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -159,6 +159,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
@@ -187,7 +188,7 @@ __global__ void
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -321,7 +322,7 @@ __global__ void
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
|
||||
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
|
||||
std::is_same<ADataType, double>::value)
|
||||
if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
{
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported())
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
|
||||
@@ -50,8 +50,9 @@ __global__ void
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
|
||||
@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
|
||||
|
||||
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
||||
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
? false
|
||||
: true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
? false
|
||||
: true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -48,8 +48,9 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -90,8 +90,9 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx11__))
|
||||
defined(__gfx11__) || defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -582,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -39,8 +39,9 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
|
||||
defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
|
||||
@@ -61,7 +61,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -166,6 +166,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -60,7 +60,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -165,6 +165,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
|
||||
if constexpr(B0EnableLds)
|
||||
{
|
||||
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
|
||||
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
B0BlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
|
||||
if constexpr(B1EnableLds)
|
||||
{
|
||||
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
|
||||
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_LRow = I2;
|
||||
#else
|
||||
constexpr auto B_LRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
B1BlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_L1>{})),
|
||||
|
||||
@@ -50,7 +50,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BBlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
|
||||
@@ -54,7 +54,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -147,7 +147,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
// printf("entry kernel launch");
|
||||
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
|
||||
|
||||
@@ -237,7 +237,7 @@ __global__ void
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
|
||||
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
|
||||
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BBlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
|
||||
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
|
||||
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
const auto M = e_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = e_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto MBlock = M / MPerBlock;
|
||||
const auto NBlock = N / NPerBlock;
|
||||
const auto MBlock = M / MPerBlock;
|
||||
const auto NBlock = N / NPerBlock;
|
||||
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
e_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
|
||||
@@ -45,7 +45,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
|
||||
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr auto B_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
|
||||
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BBlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
struct SharedMemTrait
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
|
||||
b_block_space_size_aligned * sizeof(BDataType));
|
||||
};
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
|
||||
@@ -35,8 +35,9 @@ __global__ void
|
||||
const Block2ETileMap block_2_tile_map,
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
|
||||
p_in_global,
|
||||
out_grid_desc,
|
||||
|
||||
@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
ElementwiseOperation element_op_;
|
||||
};
|
||||
|
||||
// Specialized for WMMA
|
||||
// Specialized for WMMA-Navi3
|
||||
// A single Wave32 is composed by double row
|
||||
// Data exchange allowed between these two rows
|
||||
// This RowLane Dst buf will be filled from two Src buf
|
||||
@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
|
||||
ElementwiseOperation element_op_{};
|
||||
};
|
||||
|
||||
// Specilized for WMMA-Navi4
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
bool IntraRowSwizzlePerm,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc need to known at compile-time");
|
||||
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
ignore = src_idx;
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx,
|
||||
typename DstSliceOriginIdx,
|
||||
typename SrcBuffer,
|
||||
typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf) const
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
"wrong! SliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
|
||||
"wrong! Buffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
|
||||
|
||||
// scalar per access on each dim
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
// src_desc error, non constexpr, caused by merge transform
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData v_this_row;
|
||||
// int type temp value due to intrinsic requirement
|
||||
int temp = 0;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply intra-row permute.
|
||||
if constexpr(IntraRowSwizzlePerm)
|
||||
{
|
||||
temp = __builtin_amdgcn_permlane16(
|
||||
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
|
||||
v_this_row = type_convert_sp<SrcData>(temp);
|
||||
}
|
||||
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
|
||||
});
|
||||
});
|
||||
}
|
||||
ElementwiseOperation element_op_{};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -11,12 +11,17 @@ namespace ck {
|
||||
|
||||
// Identifiers for the WMMA instruction variants that wmma_type is specialized on.
// Enumerators are numbered consecutively from 0; the gfx12 variants are appended
// after the gfx11 ones so the existing gfx11 values are unchanged.
// `wmma_i32_16x16x16_iu4` must appear exactly once and be followed by a comma
// now that more enumerators come after it.
enum struct WmmaInstr
{
    // gfx11
    wmma_f32_16x16x16_f16 = 0,
    wmma_f32_16x16x16_bf16,
    wmma_f16_16x16x16_f16,
    wmma_bf16_16x16x16_bf16,
    wmma_i32_16x16x16_iu8,
    wmma_i32_16x16x16_iu4,
    // gfx12
    wmma_f32_16x16x16_f16_gfx12,
    wmma_f32_16x16x16_bf16_gfx12,
    wmma_i32_16x16x16_iu8_gfx12,
};
|
||||
|
||||
/*
|
||||
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
}
|
||||
};
|
||||
|
||||
// gfx12
|
||||
|
||||
// A-swizzled
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
// * Data Pixel
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
// static constexpr index_t acc_data_size = 4;
|
||||
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
|
||||
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
|
||||
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
|
||||
// * num_acc_vgprs_per_wave alone M direction
|
||||
// * num_subgroups alone M direction
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = false,
|
||||
bool neg_b = false,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
|
||||
a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename src_type_a,
|
||||
typename src_type_b,
|
||||
typename dst_type,
|
||||
@@ -296,13 +417,21 @@ struct WmmaSelector
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
#else
|
||||
return WmmaInstr::wmma_f32_16x16x16_f16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
#else
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -320,8 +449,13 @@ struct WmmaSelector
|
||||
template <>
|
||||
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
#else
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu8;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
|
||||
@@ -502,6 +636,9 @@ struct WmmaGemm
|
||||
|
||||
// Returns the index of the subgroup the calling lane belongs to:
// (lane id / threads per subgroup) modulo the number of subgroups.
__device__ static auto GetSubGroupId()
{
    // The subgroups must tile the wave exactly.
    static_assert(wmma_instr.wave_size ==
                      wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups,
                  "");

    const auto subgroup_linear_id = GetLaneId() / wmma_instr.num_thread_per_subgroups;
    return subgroup_linear_id % wmma_instr.num_subgroups;
}
|
||||
|
||||
@@ -516,12 +653,20 @@ struct WmmaGemm
|
||||
|
||||
// Origin index of this thread's A data.
// When __gfx12__ is defined this is always the lane id within the subgroup;
// otherwise the swizzled low lane id is used unless TransposeC is set.
// NOTE(review): the function is __host__ __device__ but branches on __gfx12__;
// confirm which branch a host-side compilation is expected to take.
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
#ifndef __gfx12__
    return !TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
#else
    return GetLaneIdUnderSubGroup();
#endif
}
|
||||
|
||||
// Origin index of this thread's B data.
// When __gfx12__ is defined this is always the lane id within the subgroup;
// otherwise the swizzled low lane id is used only when TransposeC is set.
// NOTE(review): the function is __host__ __device__ but branches on __gfx12__;
// confirm which branch a host-side compilation is expected to take.
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
#ifndef __gfx12__
    return !TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
#else
    return GetLaneIdUnderSubGroup();
#endif
}
|
||||
|
||||
__device__ static CIndex GetBeginOfThreadBlk()
|
||||
|
||||
Reference in New Issue
Block a user