mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
MHC V3 with gemm pipeline
This commit is contained in:
@@ -48,9 +48,9 @@ int main()
|
||||
d_phi_mem.ToDevice(h_phi.data());
|
||||
d_output_mem.ToDevice(h_output.data());
|
||||
|
||||
// Define block shape
|
||||
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
|
||||
ck_tile::sequence<1, 256>,
|
||||
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
|
||||
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 1>>;
|
||||
|
||||
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;
|
||||
|
||||
@@ -49,9 +49,9 @@ int main()
|
||||
d_phi_mem.ToDevice(h_phi.data());
|
||||
d_output_mem.ToDevice(h_output.data());
|
||||
|
||||
// Define block shape
|
||||
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
|
||||
ck_tile::sequence<1, 256>,
|
||||
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
|
||||
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 1>>;
|
||||
|
||||
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;
|
||||
|
||||
@@ -49,9 +49,9 @@ int main()
|
||||
d_phi_mem.ToDevice(h_phi.data());
|
||||
d_output_mem.ToDevice(h_output.data());
|
||||
|
||||
// Define block shape
|
||||
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
|
||||
ck_tile::sequence<1, 256>,
|
||||
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
|
||||
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 1>>;
|
||||
|
||||
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;
|
||||
|
||||
Reference in New Issue
Block a user