Refine pipeline name

This commit is contained in:
rocking
2024-10-24 20:42:40 +00:00
parent c89d8ca95f
commit 871af334d1
4 changed files with 8 additions and 8 deletions

View File

@@ -35,9 +35,9 @@ struct Layernorm2dFwdPipelineOnePass
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
return "bpr_op"; // block per row
else
return "wpr"; // warp per row
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()

View File

@@ -35,9 +35,9 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
return "bpr_tp"; // block per row
else
return "wpr"; // warp per row
return "wpr_tp"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()

View File

@@ -31,9 +31,9 @@ struct Rmsnorm2dFwdPipelineOnePass
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
return "bpr_op"; // block per row
else
return "wpr"; // warp per row
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()

View File

@@ -31,9 +31,9 @@ struct Rmsnorm2dFwdPipelineTwoPass
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
return "bpr_tp"; // block per row
else
return "wpr"; // warp per row
return "wpr_tp"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()