mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
move kernel parametor to host for performance
This commit is contained in:
@@ -138,6 +138,48 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
LDSTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
struct DeviceArgument : public Argument
|
||||
{
|
||||
__host__ DeviceArgument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideC_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_,
|
||||
index_t Nr_,
|
||||
index_t Kr_)
|
||||
: Argument{p_a_grid_,
|
||||
p_b_grid_,
|
||||
p_ds_grid_,
|
||||
p_c_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideDs_,
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
a_element_op_,
|
||||
b_element_op_,
|
||||
c_element_op_},
|
||||
Nr{Nr_},
|
||||
Kr{Kr_}
|
||||
{
|
||||
}
|
||||
|
||||
index_t Nr;
|
||||
index_t Kr;
|
||||
};
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
@@ -540,7 +582,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
const auto karg = dynamic_cast<const DeviceArgument*>(&arg);
|
||||
if(NPadding && (karg->Nr != GridwiseGemm::CalculateBNShufflePadded(arg.N)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(KPadding && (karg->Kr != GridwiseGemm::CalculateBKShufflePadded(arg.K)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
@@ -568,23 +619,23 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
index_t Nr,
|
||||
index_t Kr)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
Nr,
|
||||
Kr};
|
||||
return DeviceArgument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
Nr,
|
||||
Kr};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -608,23 +659,23 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
index_t Nr,
|
||||
index_t Kr) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
Nr,
|
||||
Kr);
|
||||
return std::make_unique<DeviceArgument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
Nr,
|
||||
Kr);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -596,9 +596,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideC_,
|
||||
index_t KBatch_,
|
||||
index_t Nr_,
|
||||
index_t Kr_)
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
@@ -616,9 +614,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
MBlock{CalculateMBlock(M_)},
|
||||
NBlock{CalculateNBlock(N_)},
|
||||
BN0Shuffled{CalculateBN0Shuffled(NPadding ? CalculateBNShufflePadded(N_) : N_)},
|
||||
BK0Shuffled{CalculateBK0Shuffled(KPadding ? CalculateBKShufflePadded(K_) : K_)},
|
||||
Nr{Nr_},
|
||||
Kr{Kr_}
|
||||
BK0Shuffled{CalculateBK0Shuffled(KPadding ? CalculateBKShufflePadded(K_) : K_)}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -660,8 +656,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
// FOR PRESHUFFLE ONLY
|
||||
index_t BN0Shuffled;
|
||||
index_t BK0Shuffled;
|
||||
index_t Nr;
|
||||
index_t Kr;
|
||||
};
|
||||
|
||||
// Argument
|
||||
@@ -681,10 +675,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_,
|
||||
index_t Nr_,
|
||||
index_t Kr_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_, Nr_, Kr_},
|
||||
CElementwiseOperation c_element_op_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_ds_grid{},
|
||||
@@ -952,16 +944,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if(NPadding && (karg.Nr != CalculateBNShufflePadded(karg.N)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(KPadding && (karg.Kr != CalculateBKShufflePadded(karg.K)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
|
||||
@@ -138,6 +138,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
|
||||
std::cout << "rotating count: " << rotating_count << std::endl;
|
||||
std::cout << "verification: " << do_verification << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
@@ -325,7 +326,10 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(),
|
||||
|
||||
Reference in New Issue
Block a user