mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Gemm+Bilinear (#316)
* refactor
* update example
* update example
* gemm bilinear
* clean
* update
[ROCm/composable_kernel commit: 9e4429f9c3]
This commit is contained in:
@@ -27,8 +27,9 @@ enum struct GemmDataType
|
||||
|
||||
int profile_batched_gemm(int argc, char* argv[])
|
||||
{
|
||||
-if(argc != 15)
|
||||
+if(argc != 18)
|
||||
{
|
||||
// clang-format off
|
||||
printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n");
|
||||
printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n");
|
||||
@@ -39,7 +40,8 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg6: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: time kernel (0=n0, 1=yes)\n");
|
||||
-printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
|
||||
+printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n");
|
||||
// clang-format on
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -58,7 +60,11 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
const int StrideB = std::stoi(argv[12]);
|
||||
const int StrideC = std::stoi(argv[13]);
|
||||
|
||||
-const int BatchCount = std::stoi(argv[14]);
|
||||
+const int BatchStrideA = std::stoi(argv[14]);
|
||||
+const int BatchStrideB = std::stoi(argv[15]);
|
||||
+const int BatchStrideC = std::stoi(argv[16]);
|
||||
|
||||
+const int BatchCount = std::stoi(argv[17]);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
@@ -90,9 +96,13 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
|
||||
const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
|
||||
|
||||
-const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
|
||||
-const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
|
||||
-const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
|
||||
+const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
|
||||
+const int DefaultBatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
|
||||
+const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
|
||||
|
||||
+const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA;
|
||||
+const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB;
|
||||
+const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC;
|
||||
|
||||
bool pass = ck::profiler::
|
||||
profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
@@ -103,9 +113,9 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
-BatchStrideA,
|
||||
-BatchStrideB,
|
||||
-BatchStrideC,
|
||||
+BatchStrideA_,
|
||||
+BatchStrideB_,
|
||||
+BatchStrideC_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
|
||||
Reference in New Issue
Block a user