mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
add cpu shuffle
This commit is contained in:
@@ -63,36 +63,38 @@ struct MultiplyMultiply
|
||||
}
|
||||
};
|
||||
|
||||
// NOTE(review): this span is a rendered diff hunk, not a single function. It interleaves
// the reshapeBuffer(char*, ...) variant with the preShuffleBuffer(const FP8*, ...) variant
// (presumably removed and added lines respectively; the +/- markers are not visible), so
// both signatures, both constant sets and two outputIndex definitions appear.
// Both variants walk every (n, k) of an N x K buffer read as [n * K + k] and store each
// element at a recomputed tiled offset in the destination buffer.
void reshapeBuffer(char* buffer, int N, int K, char* output) {
|
||||
const int KRepeat = 2;
|
||||
const int NRepeat = 3;
|
||||
const int KLane = 4;
|
||||
const int NLane = 5;
|
||||
const int KPack = 6;
|
||||
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
|
||||
const int NRepeat = 1;
|
||||
const int KRepeat = 4;
|
||||
const int KLane = 2;
|
||||
const int NLane = 128;
|
||||
const int KPack = 16;
|
||||
// Tile counts along N and K; integer division discards any remainder, so N and K are
// presumably expected to be multiples of the tile extents (TODO confirm).
// NOTE(review): N0 is not referenced by any other line visible in this hunk.
int N0 = N / (NRepeat * NLane);
|
||||
int K0 = K / (KRepeat * KLane * KPack);
|
||||
|
||||
// Scratch remainders used while splitting n and k into their index components.
int tempn, tempk;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K; ++k) {
|
||||
// Tile coordinates of (n, k).
int n0 = n / (NRepeat * NLane);
|
||||
int k0 = k / (KRepeat * KLane * KPack);
|
||||
// Position inside the tile (nRel/kRel variant).
int nRel = n % (NRepeat * NLane);
|
||||
int kRel = k % (KRepeat * KLane * KPack);
|
||||
// Position inside the tile (tempn/tempk variant).
tempn = n % (NRepeat * NLane);
|
||||
tempk = k % (KRepeat * KLane * KPack);
|
||||
int n1 = tempn / NLane;
|
||||
int k1 = tempk / (KLane * KPack);
|
||||
// NOTE(review): n2 is derived from n1, not from tempn. Since tempn < NRepeat * NLane,
// n1 < NRepeat; with NRepeat = 1 as declared above, n1 and therefore n2 are always 0,
// so the position of n inside its tile never reaches outputIndex and all rows of a
// tile map to the same destination offsets. Presumably this was meant to be
// tempn % NLane; confirm.
int n2 = n1 % NLane;
|
||||
// tempk is narrowed again to the remainder within one KLane * KPack group.
tempk = tempk % (KLane * KPack);
|
||||
int k2 = tempk / KPack;
|
||||
int k3 = tempk % KPack;
|
||||
|
||||
int nIndex = nRel / NLane;
|
||||
int kIndex = kRel / (KLane * KPack);
|
||||
int nLaneIndex = nRel % NLane;
|
||||
int kLaneIndex = (kRel % (KLane * KPack)) / KPack;
|
||||
int kPackIndex = kRel % KPack;
|
||||
// Destination offset, n1/k1/k2/n2/k3 variant. Strides decrease in the order
// n0, k0, n1, k1, k2, n2, k3 (k3 has stride 1).
int outputIndex = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0
|
||||
+ k0 * KPack * NLane * KLane * KRepeat * NRepeat
|
||||
+ n1 * KPack * NLane * KLane * KRepeat
|
||||
+ k1 * KPack * NLane * KLane
|
||||
+ k2 * KPack * NLane
|
||||
+ n2 * KPack
|
||||
+ k3;
|
||||
|
||||
// Destination offset, nIndex/kIndex variant.
// NOTE(review): nLaneIndex and kLaneIndex are both scaled by KPack in this variant.
int outputIndex = (n0 * K0 + k0) * KRepeat * NRepeat * KLane * NLane * KPack
|
||||
+ nIndex * KRepeat * KLane * KPack
|
||||
+ kIndex * KLane * KPack
|
||||
+ nLaneIndex * KPack
|
||||
+ kLaneIndex * KPack
|
||||
+ kPackIndex;
|
||||
|
||||
// Element copy: char-buffer variant first, then the FP8 variant.
output[outputIndex] = buffer[n * K + k];
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -191,6 +193,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B0DataType> b0_preshuffled(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
@@ -217,15 +220,15 @@ int main(int argc, char* argv[])
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data());
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
d0_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
@@ -131,7 +131,7 @@ struct GeneratorTensor_2<ck::f8_t>
|
||||
// Produces one ck::f8_t sample; the index arguments are accepted but not used.
template <typename... Is>
|
||||
ck::f8_t operator()(Is...)
|
||||
{
|
||||
// NOTE(review): tmp is defined twice because this hunk shows both sides of the change.
// Presumably the std::rand() line is the removed one and `tmp = 1` the added one; if so,
// min_value/max_value are ignored and every generated element is 1. That looks like a
// debugging aid; confirm it is intended before relying on this generator.
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
float tmp = 1;
|
||||
return ck::type_convert<ck::f8_t>(tmp);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -281,7 +281,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
ABlockBuffer& a_block_buf0,
|
||||
ABlockBuffer& a_block_buf1,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
@@ -306,7 +307,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// // Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0);
|
||||
|
||||
// // Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
@@ -321,19 +322,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_block_buf0,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
// static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
// make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
|
||||
// b_block_buf,
|
||||
// b_thread_desc_,
|
||||
// make_tuple(n0, I0, k0, I0),
|
||||
// b_thread_buf);
|
||||
// });
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -344,9 +337,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
|
||||
@@ -364,8 +355,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
|
||||
// if(threadIdx.x==0) {
|
||||
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
|
||||
// }
|
||||
});
|
||||
|
||||
|
||||
// if(threadIdx.x==0) {
|
||||
// printf("\n");
|
||||
// }
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
@@ -387,7 +385,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_block_buf1,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
@@ -397,10 +395,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<0>{});
|
||||
@@ -441,7 +436,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_block_buf0,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
@@ -486,52 +486,52 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// if(arg.KBatch > 1)
|
||||
// {
|
||||
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
// {
|
||||
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
// GridwiseGemm,
|
||||
// true,
|
||||
// InMemoryDataOperationEnum::AtomicAdd,
|
||||
// minimum_occupancy,
|
||||
// TailNumber::Odd>;
|
||||
// Run(kernel);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
// GridwiseGemm,
|
||||
// true,
|
||||
// InMemoryDataOperationEnum::AtomicAdd,
|
||||
// minimum_occupancy,
|
||||
// TailNumber::Even>;
|
||||
// Run(kernel);
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
// {
|
||||
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
// GridwiseGemm,
|
||||
// true,
|
||||
// InMemoryDataOperationEnum::Set,
|
||||
// minimum_occupancy,
|
||||
// TailNumber::Odd>;
|
||||
// Run(kernel);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
// GridwiseGemm,
|
||||
// true,
|
||||
// InMemoryDataOperationEnum::Set,
|
||||
// minimum_occupancy,
|
||||
// TailNumber::Even>;
|
||||
// Run(kernel);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -40,6 +40,7 @@ __global__ void
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
@@ -49,42 +50,7 @@ __global__ void
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// Passing two LDS pointers is the key to telling the compiler that ds_read/write
|
||||
// operate on different LDS chunks at the same time without an ordering dependency
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
@@ -1256,6 +1222,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
DsGridPointer& p_ds_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
void* p_shared1,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
@@ -1268,6 +1235,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
p_ds_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -1284,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
DsGridPointer& p_ds_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
void* p_shared1,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
@@ -1409,6 +1378,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeB*>(p_shared) +
|
||||
@@ -1432,6 +1403,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_buf1,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
|
||||
Reference in New Issue
Block a user