Fix gemm-softmax-gemm-permute padding cases (#409)

* fix example; make padding on by default in example; fix argument checks

* fix Gemm1KPacK which has since regressed from PR #399
This commit is contained in:
Anthony Chang
2022-09-08 22:27:50 +08:00
committed by GitHub
parent ce74cea407
commit d6709dc373
4 changed files with 30 additions and 15 deletions

View File

@@ -693,9 +693,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0);
const index_t c_m = arg.c_grid_desc_g_m_n_.GetLength(I1);
const index_t c_gemm1n = arg.c_grid_desc_g_m_n_.GetLength(I2);
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))