mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4872 (commit ca623f7)
[CK] Small improvements for grouped conv backward weight (#4872)

## Motivation
Improvements to the CK Tile convolution builder run function and to the atol/rtol calculations.

## Technical Details
- Add a preprocessing function for wrw when k_batch is larger than 1 in the builder run function
- Divide the number of accums by the number of groups to get the real number of accums

## Test Plan
CI wrw tests

## Test Result
pending

## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-783
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c90a363589
commit
eede24de0d
@@ -364,7 +364,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums =
|
||||
output.GetElementSize() / (conv_param.K_ * conv_param.G_);
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Get maximum accumulated value from reference
|
||||
const std::size_t tensor_size =
|
||||
@@ -437,7 +438,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
ComputeTypeB>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums =
|
||||
output.GetElementSize() / (conv_param.K_ * conv_param.G_);
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Calculate thresholds
|
||||
auto rtol =
|
||||
|
||||
Reference in New Issue
Block a user