Add unit tests for grouped gemm two stage (#1256)

* add unit tests for grouped gemm two stage

* add reviewers suggestions

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
jakpiase
2024-05-15 10:03:39 +02:00
committed by GitHub
parent 7843a8a7fb
commit 3e3471d5d2
5 changed files with 189 additions and 6 deletions

View File

@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_d_grid_descs_m_n_.reserve(group_count_);
ds_grid_pointer_.reserve(group_count_);
group_grid_size_.reserve(group_count_);
e_ptrs_.reserve(group_count_);
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
group_grid_size_[i] = grid_size_grp;
group_grid_size_.push_back(grid_size_grp);
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
ds_grid_pointer_.push_back(p_ds_grid);
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_.push_back(p_Es[i]);
}
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_ = p_Es;
}
/**
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
arg.group_count_,
arg.gemm_kernel_args_.size(),
arg.a_element_op_,
arg.b_element_op_,
PassThrough{});
// Elementwise kernels
for(int i = 0; i < arg.group_count_; ++i)
for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
time += launch_and_time_kernel(
stream_config,