(M, N, K)=(128, 128, 128) function failed.

This commit is contained in:
mtgu0705
2025-05-11 10:16:26 +00:00
parent 0bddd63d9c
commit 726551dec4

View File

@@ -235,9 +235,9 @@ int main(int argc, char* argv[])
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
}
DeviceMem a_device_buf(sizeof(A0DataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(sizeof(A0DataType) * a_m_k.mDesc.GetElementSpaceSize() / 2);
DeviceMem a_scale_device_buf(sizeof(A1DataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(B0DataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(B0DataType) * b_k_n.mDesc.GetElementSpaceSize() / 2);
DeviceMem b_scale_device_buf(sizeof(B1DataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());