mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Fix gemm-softmax-gemm-permute padding cases (#409)
* fix example; make padding on by default in example; fix argument checks
* fix Gemm1KPacK which has since regressed from PR #399
[ROCm/composable_kernel commit: d6709dc373]
This commit is contained in:
@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
|
||||
@@ -77,7 +77,7 @@ using DeviceGemmInstance =
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
@@ -166,8 +166,6 @@ int main(int argc, char* argv[])
|
||||
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
|
||||
ck::index_t G0 = 7;
|
||||
ck::index_t G1 = 13;
|
||||
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
|
||||
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -204,6 +202,9 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
|
||||
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
|
||||
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
|
||||
|
||||
Reference in New Issue
Block a user