Attention with output permutation (#370)

* comment on specialization for TensorSpecialization::Packed

* gemm_softmax_gemm with output permutation

* scaling

* refactor MatrixPadder; rename to GemmPadder

* remove old sanity check

* restore original gemm_softmax_gemm

* revise comment in gemm_softmax_gemm example

* use GetElementSpaceSize()

* remove extra header

* typo

* remove archaic DeviceOpPtr
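
The headline change, output permutation for gemm_softmax_gemm (the fused attention kernel), lets the second GEMM store its result directly through a permuted output descriptor rather than running a separate transpose kernel afterwards. A minimal sketch of the layout idea in plain C++, with illustrative shapes and names (G0 = batch, G1 = heads, M = sequence length, O = head size; none of this is the CK API):

#include <cstddef>
#include <vector>

int main()
{
    // The per-head result is logically [G0, G1, M, O], but downstream layers
    // usually consume [G0, M, G1, O]. "Output permutation" means writing each
    // accumulator element through the strides of the permuted layout.
    const std::size_t G0 = 2, G1 = 3, M = 4, O = 5;
    std::vector<float> out(G0 * M * G1 * O);

    // Strides of the permuted [G0, M, G1, O] layout (innermost stride is 1).
    const std::size_t s_g0 = M * G1 * O, s_m = G1 * O, s_g1 = O;

    for(std::size_t g0 = 0; g0 < G0; ++g0)
        for(std::size_t g1 = 0; g1 < G1; ++g1) // loop order of the per-head GEMM
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t o = 0; o < O; ++o)
                {
                    const float acc = 0.0f; // stand-in for a gemm1 accumulator value
                    out[g0 * s_g0 + m * s_m + g1 * s_g1 + o] = acc;
                }
}
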
Author: Anthony Chang
Date: 2022-08-24 03:52:56 +08:00
Committed by: GitHub
Parent: 6091458300
Commit: e0d8806ca1

11 changed files with 1501 additions and 74 deletions
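
On the "use GetElementSpaceSize()" bullet above: for strided or padded descriptors, the number of elements a buffer must hold is the descriptor's element space, not the product of its lengths. A plain-C++ illustration of the distinction (hypothetical helper, not the CK API):

#include <cstddef>

// Element *space* of a row-major M x N view with row stride StrideM >= N:
// the offset of the last reachable element, plus one. For StrideM > N this
// exceeds the element *count* M * N, which would under-allocate the buffer.
std::size_t element_space_size(std::size_t M, std::size_t N, std::size_t StrideM)
{
    return (M - 1) * StrideM + (N - 1) + 1;
}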

@@ -249,8 +249,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             return false;
         }
 
-        assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock);
-
         if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
         {
             return false;

@@ -245,8 +245,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return false;
         }
 
-        assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock);
-
         if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
         {
             return false;
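
Both hunks drop the same assert on the gemm1 K-loop decomposition while keeping the block_2_ctile_map check beneath it. A hedged reading (the commit message says only "remove old sanity check"): assert() compiles out under NDEBUG and aborts when it does fire, while the surviving pattern rejects unsupported problems recoverably, along these lines (hypothetical names):

// Sketch of the surviving style: report an unsupported gemm1 K-loop
// decomposition by returning false instead of asserting.
bool CheckGemm1KLoops(int num_outer, int num_inner, int N, int Gemm1KPerBlock)
{
    if(num_outer * num_inner != N / Gemm1KPerBlock)
    {
        return false; // caller treats the argument as unsupported
    }
    return true;
}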