mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Pass gemm_descs for grouped gemm via __constant__ buff (#232)
* moved gemm_descs_args into const buff * use CK_CONSTANT_ADDRESS_SPACE instead of global constant * clean * moved hipMemAlloc outside of deviceOp * add SetWorkSpacePointer * fix ignore
This commit is contained in:
@@ -42,6 +42,8 @@ struct BaseOperator
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual void SetWorkSpacePointer(BaseArgument*, void*) const {}
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
|
||||
@@ -24,57 +24,33 @@ template <typename GridwiseGemm,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
bool HasMainKBlockLoop,
|
||||
index_t MaxGroupCount>
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_gemm_xdlops_v2r3(
|
||||
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
|
||||
const index_t group_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op)
|
||||
kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
#if 1
|
||||
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
|
||||
if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
|
||||
i < group_count)
|
||||
{
|
||||
auto group_id = i;
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
gemm_descs[group_id].a_ptr,
|
||||
gemm_descs[group_id].b_ptr,
|
||||
gemm_descs[group_id].c_ptr,
|
||||
p_shared,
|
||||
gemm_descs[group_id].a_grid_desc_k0_m_k1_,
|
||||
gemm_descs[group_id].b_grid_desc_k0_n_k1_,
|
||||
gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
|
||||
}
|
||||
});
|
||||
#else
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
index_t group_id = 0;
|
||||
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
|
||||
group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
|
||||
i < group_count)
|
||||
? i
|
||||
: group_id;
|
||||
});
|
||||
|
||||
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
|
||||
for(index_t i = 0; i < group_count; i++)
|
||||
{
|
||||
group_id =
|
||||
(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
|
||||
? i
|
||||
: group_id;
|
||||
}
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
gemm_desc_ptr[group_id].a_ptr,
|
||||
@@ -87,11 +63,9 @@ __global__ void
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
gemm_desc_ptr[group_id].block_2_ctile_map_,
|
||||
block_id_grp);
|
||||
#endif
|
||||
gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
|
||||
#else
|
||||
ignore = gemm_descs;
|
||||
ignore = gemm_descs_const;
|
||||
ignore = group_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl
|
||||
{
|
||||
grid_size_ = 0;
|
||||
|
||||
gemm_descs_args_workspace_ = nullptr;
|
||||
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
|
||||
|
||||
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
|
||||
@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl
|
||||
|
||||
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
|
||||
|
||||
void* gemm_descs_args_workspace_;
|
||||
|
||||
index_t grid_size_;
|
||||
};
|
||||
|
||||
@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
|
||||
|
||||
bool has_main_k_block_loop = true;
|
||||
|
||||
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
|
||||
if(i < arg.gemm_desc_kernel_arg_.size())
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
|
||||
|
||||
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
|
||||
|
||||
std::cout << ", arg.c_grid_desc_m_n_{ "
|
||||
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
|
||||
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
|
||||
arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
|
||||
arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
|
||||
{
|
||||
gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];
|
||||
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
|
||||
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
|
||||
|
||||
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
|
||||
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
|
||||
|
||||
std::cout << ", arg.c_grid_desc_m_n_{ "
|
||||
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
|
||||
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
|
||||
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
|
||||
gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
|
||||
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
|
||||
{
|
||||
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
});
|
||||
|
||||
const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
|
||||
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
|
||||
{
|
||||
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
|
||||
}
|
||||
}
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemcpy(arg.gemm_descs_args_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -523,23 +501,23 @@ struct DeviceGroupedGemmXdl
|
||||
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<GemmDescKernelArg>,
|
||||
GemmDescKernelArg,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
true,
|
||||
MaxGroupCount>;
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_desc_kernel_args,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -547,23 +525,23 @@ struct DeviceGroupedGemmXdl
|
||||
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<GemmDescKernelArg>,
|
||||
GemmDescKernelArg,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
false,
|
||||
MaxGroupCount>;
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_desc_kernel_args,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
|
||||
{
|
||||
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user