mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
minor fix
This commit is contained in:
@@ -410,18 +410,18 @@ def get_blobs() -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
#if hdim == 256:
|
||||
if True:
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, mask))
|
||||
#else:
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse kernels
|
||||
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
|
||||
|
||||
@@ -31,7 +31,7 @@ get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLe
|
||||
|
||||
constexpr auto slice_lengths = slice_ends - slice_begins;
|
||||
|
||||
return make_tile_window(tile.GetBottomTensorView(),
|
||||
return make_tile_window(tile.get_bottom_tensor_view(),
|
||||
sequence_to_tuple_of_number(slice_lengths),
|
||||
to_multi_index(slice_begins));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user