mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Added b preshuffle pipeline v3 support.
This commit is contained in:
@@ -28,9 +28,9 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
static constexpr ck::index_t KPerBlock = 128;
|
||||
|
||||
// clang-format off
|
||||
#if 0
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
|
||||
ALayout, BLayout, CLayout,
|
||||
@@ -38,7 +38,7 @@ using DeviceGemmV2Instance =
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128,
|
||||
KPerBlock, 16, 32,
|
||||
256, 16, 32,
|
||||
32, 32,
|
||||
4, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
@@ -47,7 +47,26 @@ using DeviceGemmV2Instance =
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
|
||||
|
||||
#else
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256,
|
||||
256, 256,
|
||||
128, 16, 32,
|
||||
32, 32,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F8, F8, PermuteA, PermuteB>;
|
||||
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
template <typename ProblemType>
|
||||
|
||||
@@ -510,10 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
// Global prefetch A1 B1
|
||||
@@ -545,6 +548,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
make_tuple(I0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I0));
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
@@ -594,9 +604,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_dequant_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
@@ -633,6 +643,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(local_read_buf));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -652,6 +669,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(mfma_reg_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(mfma_reg_buf));
|
||||
}
|
||||
|
||||
HotLoopScheduler(m0);
|
||||
@@ -691,7 +715,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
@@ -720,6 +744,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -732,6 +763,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I0));
|
||||
}
|
||||
|
||||
EpilogueScheduler_1(m0);
|
||||
@@ -748,7 +786,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
@@ -776,6 +814,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I1));
|
||||
|
||||
EpilogueScheduler_2();
|
||||
}
|
||||
@@ -797,7 +842,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
@@ -823,6 +868,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I0));
|
||||
|
||||
EpilogueScheduler_2();
|
||||
}
|
||||
@@ -855,6 +907,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
KPack>;
|
||||
|
||||
const PassThrough b_element_op{};
|
||||
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user