mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Refactor threadwise copy using sfcurve (#101)
* add space_filling_curve * cleanup and move space_filling_curve into test * WIP: start refactoring threadwise_transfer_v1r3 * threadwise_copy works but needs further refactoring * add some comments * add SpaceFillingCurve::GetIndices() * minor changes * removed GetIndices; refactored GetDstCoordinateResetStep * add DynamicBuffer::Transfer, but Add is not tested * rebased against develop * threadwise_copy_v6r1/v6r2/v6r3 using space-filling curve start to work * minor changes * refactored threadcopy v3r1, v2; removed old implementations * clang-format * cleanup * fix a typo in v6r3 * format Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -69,7 +69,6 @@ struct gemmArgs
|
||||
int KBatch;
|
||||
};
|
||||
|
||||
|
||||
int test_gemm(const gemmArgs& args)
|
||||
{
|
||||
bool a_row_major, b_row_major, c_row_major;
|
||||
@@ -115,8 +114,10 @@ int test_gemm(const gemmArgs& args)
|
||||
|
||||
Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
|
||||
Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
|
||||
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
Tensor<float> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
Tensor<float> c_m_n_device_result(
|
||||
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
|
||||
// init data
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
@@ -205,7 +206,7 @@ int test_gemm(const gemmArgs& args)
|
||||
else
|
||||
{
|
||||
std::cout << "test split k: Fail " << std::endl;
|
||||
error_code = -1; // test needs to report failure
|
||||
error_code = -1; // test needs to report failure
|
||||
}
|
||||
return error_code;
|
||||
}
|
||||
@@ -221,17 +222,17 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
|
||||
const int M = std::stoi(argv[2]);
|
||||
const int N = std::stoi(argv[3]);
|
||||
const int K = std::stoi(argv[4]);
|
||||
const int M = std::stoi(argv[2]);
|
||||
const int N = std::stoi(argv[3]);
|
||||
const int K = std::stoi(argv[4]);
|
||||
|
||||
const int StrideA = std::stoi(argv[5]);
|
||||
const int StrideB = std::stoi(argv[6]);
|
||||
const int StrideC = std::stoi(argv[7]);
|
||||
const int KBatch = std::stoi(argv[8]);
|
||||
test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
|
||||
const int StrideA = std::stoi(argv[5]);
|
||||
const int StrideB = std::stoi(argv[6]);
|
||||
const int StrideC = std::stoi(argv[7]);
|
||||
const int KBatch = std::stoi(argv[8]);
|
||||
test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -242,12 +243,11 @@ int main(int argc, char* argv[])
|
||||
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
|
||||
return -1;
|
||||
}
|
||||
for(const auto& kinder: test_cases)
|
||||
for(const auto& kinder : test_cases)
|
||||
{
|
||||
const auto res = test_gemm(kinder);
|
||||
if(!res)
|
||||
return -1;
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user