add padding info

This commit is contained in:
qin letao
2025-02-22 09:25:19 +00:00
parent 01c5d3fa61
commit 23b2fb0f3e
3 changed files with 37 additions and 25 deletions

View File

@@ -242,12 +242,12 @@ int main(int argc, char* argv[])
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto Knew = (K + 64 - 1) / 64;
auto StrideBnew = Knew;
auto Knew = (K + 64 - 1) / 64 * 64;
auto StrideBnew = Knew;
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B0DataType> b0_preshuffled(
f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
f_host_tensor_descriptor(Knew, N, StrideBnew, B0Layout{})); // use layout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
@@ -281,7 +281,7 @@ int main(int argc, char* argv[])
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_preshuffled.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
@@ -304,6 +304,7 @@ int main(int argc, char* argv[])
int NPerXdl = device_op.GetPreShuffleParameters();
b0_preshuffled.SetZero();
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());

View File

@@ -355,6 +355,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
throw std::runtime_error("todo: only v1 v2 and v3 support now");
}
}
else
{
throw std::runtime_error("not call kernel function");
}
#if 0
else
{
@@ -526,7 +530,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
return false;
}
if(arg.N % NPerBlock != 0 || (arg.K % KPerBlock != 0 && GemmSpec != GemmSpecialization::KPadding))
if(arg.N % NPerBlock != 0 ||
(arg.K % KPerBlock != 0 && GemmSpec != GemmSpecialization::KPadding))
{
return false;
}

View File

@@ -349,27 +349,33 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
}
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
__host__ __device__ static auto
MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0, index_t KPad)
{
// using GemmSpecialization = tensor_operation::device::GemmSpecialization;
//if K padding
// if constexpr(GemmSpec == GemmSpecialization::KPadding ||
// GemmSpec == GemmSpecialization::NKPadding)
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// if K padding
if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// origin: [N0,K0,KLane,NLane,KPack]
// const auto b_grid_desc_raw = make_naive_tensor_descriptor(
// make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
// make_tuple(
// NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
// const auto b_grid_desc_n_k =
// transform_tensor_descriptor(b_grid_desc_nraw_kraw,
// make_tuple(make_pass_through_transform(N),
// make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// ignore = b_grid_desc_n_k;
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
const auto b_grid_desc_raw = make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(
NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
auto K0new = CalculateBK0Shuffled(KPad);
return transform_tensor_descriptor(
b_grid_desc_raw,
make_tuple(make_pass_through_transform(N0 / NWave),
make_pass_through_transform(NWave),
make_right_pad_transform(K0, K0new - K0),
make_pass_through_transform(NkSwizzleNumber)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
// else
else
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
return make_naive_tensor_descriptor(
@@ -591,7 +597,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)},
BN0Shuffled{CalculateBN0Shuffled(N_)},
BK0Shuffled{CalculateBK0Shuffled(K_)}
BK0Shuffled{CalculateBK0Shuffled((K_ + 64 - 1) / 64 * 64)}
{
}
@@ -1592,8 +1598,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(
problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);