mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
add padding info
This commit is contained in:
@@ -242,12 +242,12 @@ int main(int argc, char* argv[])
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
auto Knew = (K + 64 - 1) / 64;
|
||||
auto StrideBnew = Knew;
|
||||
auto Knew = (K + 64 - 1) / 64 * 64;
|
||||
auto StrideBnew = Knew;
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B0DataType> b0_preshuffled(
|
||||
f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
|
||||
f_host_tensor_descriptor(Knew, N, StrideBnew, B0Layout{})); // use laout only for size
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
@@ -281,7 +281,7 @@ int main(int argc, char* argv[])
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_preshuffled.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
@@ -304,6 +304,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
b0_preshuffled.SetZero();
|
||||
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
@@ -355,6 +355,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
throw std::runtime_error("todo: only v1 v2 and v3 support now");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("not call kernel function");
|
||||
}
|
||||
#if 0
|
||||
else
|
||||
{
|
||||
@@ -526,7 +530,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if(arg.N % NPerBlock != 0 || (arg.K % KPerBlock != 0 && GemmSpec != GemmSpecialization::KPadding))
|
||||
if(arg.N % NPerBlock != 0 ||
|
||||
(arg.K % KPerBlock != 0 && GemmSpec != GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -349,27 +349,33 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
__host__ __device__ static auto
|
||||
MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0, index_t KPad)
|
||||
{
|
||||
// using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
//if K padding
|
||||
// if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
// GemmSpec == GemmSpecialization::NKPadding)
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
// if K padding
|
||||
if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// origin: [N0,K0,KLane,NLane,KPack]
|
||||
// const auto b_grid_desc_raw = make_naive_tensor_descriptor(
|
||||
// make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
// make_tuple(
|
||||
// NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
// const auto b_grid_desc_n_k =
|
||||
// transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
// make_tuple(make_pass_through_transform(N),
|
||||
// make_right_pad_transform(K, KPad - K)),
|
||||
// make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
// make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
// ignore = b_grid_desc_n_k;
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
const auto b_grid_desc_raw = make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
make_tuple(
|
||||
NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
|
||||
auto K0new = CalculateBK0Shuffled(KPad);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_raw,
|
||||
make_tuple(make_pass_through_transform(N0 / NWave),
|
||||
make_pass_through_transform(NWave),
|
||||
make_right_pad_transform(K0, K0new - K0),
|
||||
make_pass_through_transform(NkSwizzleNumber)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
}
|
||||
// else
|
||||
else
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
@@ -591,7 +597,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
MBlock{CalculateMBlock(M_)},
|
||||
NBlock{CalculateNBlock(N_)},
|
||||
BN0Shuffled{CalculateBN0Shuffled(N_)},
|
||||
BK0Shuffled{CalculateBK0Shuffled(K_)}
|
||||
BK0Shuffled{CalculateBK0Shuffled((K_ + 64 - 1) / 64 * 64)}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -1592,8 +1598,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
|
||||
const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(
|
||||
problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user