move kernel parametor to host for performance

This commit is contained in:
letaoqin
2025-03-13 11:07:34 +00:00
parent dc890c0f2d
commit 457d8b4f85
3 changed files with 94 additions and 57 deletions

View File

@@ -138,6 +138,48 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
LDSTypeB>;
using Argument = typename GridwiseGemm::Argument;
struct DeviceArgument : public Argument
{
__host__ DeviceArgument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
index_t Nr_,
index_t Kr_)
: Argument{p_a_grid_,
p_b_grid_,
p_ds_grid_,
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideDs_,
StrideC_,
k_batch_,
a_element_op_,
b_element_op_,
c_element_op_},
Nr{Nr_},
Kr{Kr_}
{
}
index_t Nr;
index_t Kr;
};
int GetPreShuffleParameters() override { return NPerXDL; }
@@ -540,7 +582,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{
return false;
}
const auto karg = dynamic_cast<const DeviceArgument*>(&arg);
if(NPadding && (karg->Nr != GridwiseGemm::CalculateBNShufflePadded(arg.N)))
{
return false;
}
if(KPadding && (karg->Kr != GridwiseGemm::CalculateBKShufflePadded(arg.K)))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
@@ -568,23 +619,23 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
index_t Nr,
index_t Kr)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op,
Nr,
Kr};
return DeviceArgument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op,
Nr,
Kr};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -608,23 +659,23 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
index_t Nr,
index_t Kr) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op,
Nr,
Kr);
return std::make_unique<DeviceArgument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op,
Nr,
Kr);
}
// polymorphic

View File

@@ -596,9 +596,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t KBatch_,
index_t Nr_,
index_t Kr_)
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
@@ -616,9 +614,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)},
BN0Shuffled{CalculateBN0Shuffled(NPadding ? CalculateBNShufflePadded(N_) : N_)},
BK0Shuffled{CalculateBK0Shuffled(KPadding ? CalculateBKShufflePadded(K_) : K_)},
Nr{Nr_},
Kr{Kr_}
BK0Shuffled{CalculateBK0Shuffled(KPadding ? CalculateBKShufflePadded(K_) : K_)}
{
}
@@ -660,8 +656,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
// FOR PRESHUFFLE ONLY
index_t BN0Shuffled;
index_t BK0Shuffled;
index_t Nr;
index_t Kr;
};
// Argument
@@ -681,10 +675,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
index_t Nr_,
index_t Kr_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_, Nr_, Kr_},
CElementwiseOperation c_element_op_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{},
@@ -952,16 +944,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
return false;
}
if(NPadding && (karg.Nr != CalculateBNShufflePadded(karg.N)))
{
return false;
}
if(KPadding && (karg.Kr != CalculateBKShufflePadded(karg.K)))
{
return false;
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||

View File

@@ -138,6 +138,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
std::cout << "verification: " << do_verification << std::endl;
switch(init_method)
{
@@ -325,7 +326,10 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
<< std::endl;
}
}
if(!pass)
{
continue;
}
std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(),