mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
clang format
This commit is contained in:
@@ -58,10 +58,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
|
||||
//> store rows/cols into thread registers in chunks of 16
|
||||
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
static constexpr index_t APackedSize = is_same_v<remove_cvref_t<ComputeTypeA>, f4x2_pk_t>
|
||||
? 2
|
||||
: 1;
|
||||
static constexpr index_t KThreadChunk = 16 * APackedSize/ sizeof(ComputeTypeA);
|
||||
static constexpr index_t APackedSize =
|
||||
is_same_v<remove_cvref_t<ComputeTypeA>, f4x2_pk_t> ? 2 : 1;
|
||||
static constexpr index_t BPackedSize =
|
||||
is_same_v<remove_cvref_t<ComputeTypeB>, f4x2_pk_t> ? 2 : 1;
|
||||
|
||||
static constexpr index_t KThreadChunk = 16 * APackedSize / sizeof(ComputeTypeA);
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
@@ -327,18 +329,18 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
// Read buffer + Compute buffer
|
||||
// A[M0, M1, M2, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / 2>{}),
|
||||
make_tuple(Number<KPack / 2>{},
|
||||
Number<KRepeat * MRepeat * KPack / 2>{},
|
||||
Number<MRepeat * KPack / 2>{},
|
||||
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / APackedSize>{}),
|
||||
make_tuple(Number<KPack / APackedSize>{},
|
||||
Number<KRepeat * MRepeat * KPack / APackedSize>{},
|
||||
Number<MRepeat * KPack / APackedSize>{},
|
||||
I1));
|
||||
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / 2>{}),
|
||||
make_tuple(Number<KPack / 2>{},
|
||||
Number<KRepeat * NRepeat * KPack / 2>{},
|
||||
Number<NRepeat * KPack / 2>{},
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / BPackedSize>{}),
|
||||
make_tuple(Number<KPack / BPackedSize>{},
|
||||
Number<KRepeat * NRepeat * KPack / BPackedSize>{},
|
||||
Number<NRepeat * KPack / BPackedSize>{},
|
||||
I1));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
|
||||
@@ -141,7 +141,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::APackedSize;
|
||||
using Base::BMmaKStride;
|
||||
using Base::BPackedSize;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
@@ -539,10 +541,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec; // = vec: pk_i4_t, 32
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack / APackedSize>
|
||||
a_thread_vec; // = vec: pk_i4_t, 32
|
||||
vector_type<ComputeTypeB, KPack / BPackedSize> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / 2, 1>{}([&](auto ik) {
|
||||
static_for<0, KPack / APackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
@@ -578,12 +581,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
// CK_PRINT<decltype(xdlops_gemm)>();
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
// mfma input type = pk_f4_t, 32
|
||||
// CK_PRINT<mfma_input_type_a>();
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -721,10 +726,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack / APackedSize> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / BPackedSize> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / 2, 1>{}([&](auto ik) {
|
||||
static_for<0, KPack / APackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
@@ -751,9 +756,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -805,10 +812,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack / APackedSize> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / BPackedSize> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / 2, 1>{}([&](auto ik) {
|
||||
static_for<0, KPack / APackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
@@ -835,9 +842,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -858,10 +867,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack / APackedSize> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / BPackedSize> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / 2, 1>{}([&](auto ik) {
|
||||
static_for<0, KPack / APackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
@@ -888,9 +897,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / 2>::type;
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
@@ -344,22 +344,22 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
@@ -369,20 +369,20 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1247,16 +1247,16 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (casting each element from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector / 2> dst_tmp_vector;
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector / PackedSize> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector / 2, 1>{}([&](auto i) {
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
dst_tmp_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector / 2, 1>{}([&](auto i) {
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user