mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add unit tests for grouped gemm two stage (#1256)
* add unit tests for grouped gemm two stage * add reviewers suggestions --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_d_grid_descs_m_n_.reserve(group_count_);
|
||||
ds_grid_pointer_.reserve(group_count_);
|
||||
group_grid_size_.reserve(group_count_);
|
||||
e_ptrs_.reserve(group_count_);
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
group_grid_size_[i] = grid_size_grp;
|
||||
group_grid_size_.push_back(grid_size_grp);
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
|
||||
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
|
||||
ds_grid_pointer_.push_back(p_ds_grid);
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_.push_back(p_Es[i]);
|
||||
}
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_ = p_Es;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(dev_gemm_args),
|
||||
arg.group_count_,
|
||||
arg.gemm_kernel_args_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
PassThrough{});
|
||||
|
||||
// Elementwise kernels
|
||||
for(int i = 0; i < arg.group_count_; ++i)
|
||||
for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
|
||||
Reference in New Issue
Block a user