[rocm-libraries] ROCm/rocm-libraries#4638 (commit 305ec71)

[ck] Support VGPR estimate in GridwiseGemm_wmma_cshuffle_v3
 (#4638)

1. Add GetEstimateVgprCount to estimate the VGPR usage in
GridwiseGemm_wmma_cshuffle_v3
2. Add IsValidCompilationParameter to disable kernel which use too many
vgprs.
- Currently, the threashold is AvailableVgprCount * 1.25
3. Modify examples to avoid test is disabled on gfx11

It is port from internal repo
PR[#192](https://github.com/ROCm/composable_kernel/issues/192)

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
linqunAMD
2026-02-20 15:57:18 +00:00
committed by assistant-librarian[bot]
parent 7689090739
commit 29781f2ac4
4 changed files with 465 additions and 374 deletions

View File

@@ -1785,12 +1785,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
p_ds_grid_dummy[i] = nullptr;
StrideDs_dummy[i] = I0;
});
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); i++)
{
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1);
const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
const index_t GemmM = arg.a_grid_desc_m_k_container_[i].GetLength(I0);
const index_t GemmN = arg.b_grid_desc_n_k_container_[i].GetLength(I0);
const index_t GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
// Create gemm arguments with dummy values to check for validity
typename GridwiseGemmCTranspose::Argument gemm_arg{
std::array<const void*, 1>{nullptr}, // p_as_grid