custom_flashinfer/docs/api/logits_processor.rst
Shanli Xing 44f7d0b3db doc: fix LogitsPipe example (#1110)
2025-06-02 23:26:17 -07:00


.. _apilogitsprocessor:

flashinfer.logits_processor
===========================

A declarative, pluggable framework for building processing pipelines for LLM outputs.

.. currentmodule:: flashinfer.logits_processor

Pipeline Construction
---------------------

Use :class:`LogitsPipe` to create processing pipelines:

.. code-block:: python

    import torch
    from flashinfer.logits_processor import LogitsPipe, Temperature, Softmax, TopP, Sample

    # Create a pipeline
    pipe = LogitsPipe([
        Temperature(),  # Scale logits by temperature
        Softmax(),      # Convert logits to probabilities
        TopP(),         # Apply top-p filtering
        Sample()        # Sample from the distribution
    ])

    # Apply the pipeline
    batch_size = 4
    vocab_size = 5
    logits = torch.randn(batch_size, vocab_size, device="cuda")
    output_ids = pipe(logits, temperature=0.7, top_p=0.9)
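
For intuition, the following sketch spells out in plain Python what the pipeline above computes for each row: divide the logits by the temperature, apply softmax, keep the smallest set of most likely tokens whose cumulative probability reaches ``top_p``, renormalize, and draw one token. It is a hypothetical reference for illustration only. The helper names (``reference_pipe`` and the functions it calls) are not part of FlashInfer, FlashInfer runs these steps as GPU kernels that are implemented differently, and its exact top-p threshold and tie-breaking conventions may differ from this sketch.

.. code-block:: python

    import math
    import random
    from typing import List


    def apply_temperature(logits: List[float], temperature: float) -> List[float]:
        # Divide each logit by the temperature (< 1 sharpens, > 1 flattens)
        return [x / temperature for x in logits]


    def softmax(logits: List[float]) -> List[float]:
        # Subtract the max before exponentiating for numerical stability
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]


    def top_p_filter(probs: List[float], top_p: float) -> List[float]:
        # Keep the most likely tokens until their cumulative mass reaches top_p,
        # zero out the rest, then renormalize
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        kept, cumulative = set(), 0.0
        for i in order:
            kept.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        masked = [p if i in kept else 0.0 for i, p in enumerate(probs)]
        total = sum(masked)
        return [p / total for p in masked]


    def sample(probs: List[float], rng: random.Random) -> int:
        # Draw one token id from the renormalized categorical distribution
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]


    def reference_pipe(batch_logits: List[List[float]], temperature: float,
                       top_p: float, seed: int = 0) -> List[int]:
        # Hypothetical row-by-row equivalent of
        # LogitsPipe([Temperature(), Softmax(), TopP(), Sample()])
        rng = random.Random(seed)
        return [
            sample(top_p_filter(softmax(apply_temperature(row, temperature)), top_p), rng)
            for row in batch_logits
        ]


    # With these logits, tokens 0 and 1 already cover more than 90% of the mass
    print(reference_pipe([[2.0, 1.0, 0.1, -1.0, -3.0]] * 4, temperature=0.7, top_p=0.9))
    # four ids, each 0 or 1
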

Pipeline
--------

.. autosummary::
    :toctree: ../generated

    LogitsPipe

Processors
----------

.. autosummary::
    :toctree: ../generated

    LogitsProcessor
    Temperature
    Softmax
    TopK
    TopP
    MinP
    Sample
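
The pipeline example uses Temperature, Softmax, TopP and Sample. As a rough guide to the two other filters in this list, the sketch below applies the common definitions of top-k (keep the ``k`` most likely tokens) and min-p (keep tokens whose probability is at least ``min_p`` times the largest one) to a single row of probabilities. It is plain Python with hypothetical function names, not FlashInfer code, and it always renormalizes; FlashInfer may instead mask logits when these processors receive logits rather than probabilities.

.. code-block:: python

    from typing import List


    def _renormalize(masked: List[float]) -> List[float]:
        # Rescale the surviving probabilities so they sum to 1
        total = sum(masked)
        return [p / total for p in masked]


    def top_k_filter(probs: List[float], k: int) -> List[float]:
        # Keep the k most likely tokens; sorted() is stable, so ties favor the lower index
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        keep = set(ranked[:k])
        return _renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])


    def min_p_filter(probs: List[float], min_p: float) -> List[float]:
        # Keep tokens whose probability is at least min_p times the top probability
        threshold = min_p * max(probs)
        return _renormalize([p if p >= threshold else 0.0 for p in probs])


    probs = [0.5, 0.3, 0.15, 0.05]
    print(top_k_filter(probs, 2))     # approximately [0.625, 0.375, 0.0, 0.0]
    print(min_p_filter(probs, 0.25))  # threshold 0.125: keeps the first three tokens
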

Types
-----

.. autosummary::
    :toctree: ../generated

    TensorType
    TaggedTensor

Customization Features
----------------------
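
Custom Logits Processor
^^^^^^^^^^^^^^^^^^^^^^^

You can create your own logits processor by subclassing :class:`LogitsProcessor` and implementing ``legalize``, which returns the ops that implement the processor for a given input type:

.. code-block:: python

    from typing import Any, List

    # Op is assumed to be importable from flashinfer.logits_processor as well
    from flashinfer.logits_processor import LogitsPipe, LogitsProcessor, Op, TaggedTensor, TensorType

    class CustomLogitsProcessor(LogitsProcessor):
        def __init__(self, **params: Any):
            super().__init__(**params)

        def legalize(self, input_type: TensorType) -> List["Op"]:
            # Lower this processor into the op(s) that implement it
            return [CustomOp(**self.params)]

    class CustomOp(Op):
        # Define the input and output tensor types
        IN = TensorType.LOGITS
        OUT = TensorType.LOGITS

        def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor:
            # Transform the logits here; this placeholder returns them unchanged
            return tensor

    pipe = LogitsPipe([CustomLogitsProcessor()])  # The pipe will be compiled into [CustomOp]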
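
Conceptually, ``legalize`` is where a processor is lowered into concrete ops, and the op it picks can depend on the tensor type it receives; the ``IN`` and ``OUT`` declarations presumably let consecutive ops be checked against each other. The standalone toy model below illustrates both ideas under that assumption. Every name in it (``Kind``, ``ToyOp``, ``ToyTopK``, ``toy_legalize`` and the op labels) is hypothetical, and it does not reproduce FlashInfer's compiler.

.. code-block:: python

    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import List


    class Kind(Enum):
        # Hypothetical stand-in for TensorType
        LOGITS = auto()
        PROBS = auto()
        INDICES = auto()


    @dataclass(frozen=True)
    class ToyOp:
        name: str
        IN: Kind
        OUT: Kind


    class ToySoftmax:
        def legalize(self, input_type: Kind) -> List[ToyOp]:
            return [ToyOp("softmax", Kind.LOGITS, Kind.PROBS)]


    class ToyTopK:
        def legalize(self, input_type: Kind) -> List[ToyOp]:
            # Same processor, different op depending on the tensor type it receives
            if input_type is Kind.PROBS:
                return [ToyOp("top_k_on_probs", Kind.PROBS, Kind.PROBS)]
            return [ToyOp("top_k_on_logits", Kind.LOGITS, Kind.LOGITS)]


    class ToySample:
        def legalize(self, input_type: Kind) -> List[ToyOp]:
            return [ToyOp(f"sample_from_{input_type.name.lower()}", input_type, Kind.INDICES)]


    def toy_legalize(processors, input_type: Kind = Kind.LOGITS) -> List[ToyOp]:
        ops, current = [], input_type
        for proc in processors:
            for op in proc.legalize(current):
                # Reject an op whose declared input does not match the running type
                if op.IN is not current:
                    raise TypeError(f"{op.name} expects {op.IN.name}, got {current.name}")
                ops.append(op)
                current = op.OUT
        return ops


    print([op.name for op in toy_legalize([ToyTopK(), ToySample()])])
    # ['top_k_on_logits', 'sample_from_logits']
    print([op.name for op in toy_legalize([ToySoftmax(), ToyTopK(), ToySample()])])
    # ['softmax', 'top_k_on_probs', 'sample_from_probs']

Placing ``ToySample`` before ``ToyTopK`` raises ``TypeError`` in this model, because no top-k op accepts ``INDICES``.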
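
Custom Fusion Rules
^^^^^^^^^^^^^^^^^^^

You can supply custom fusion rules through the ``custom_fusion_rules`` argument of :class:`LogitsPipe` to optimize specific processor combinations:

.. code-block:: python

    from typing import List

    # FusionRule and Op are assumed to be importable from flashinfer.logits_processor;
    # CustomOp is the op defined in the previous example
    from flashinfer.logits_processor import FusionRule, LogitsPipe, Op, Sample, Softmax, Temperature

    def custom_fusion_guard(window: List[Op]) -> bool:
        # Whether the fusion should be applied to this window of ops
        return True

    def build_custom_fusion(window: List[Op]) -> Op:
        # Create a fused operator, e.g. by copying parameters from the ops in the window.
        # A real replacement for Temperature -> Softmax would take LOGITS and produce PROBS.
        return CustomOp()

    custom_rule = FusionRule(
        pattern=(Temperature, Softmax),
        guard=custom_fusion_guard,
        build=build_custom_fusion,
        prio=20,
    )

    pipe = LogitsPipe(
        [Temperature(), Softmax(), Sample()],
        custom_fusion_rules=[custom_rule],
    )  # The compiled ops in the pipeline will be [CustomOp, Sample]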
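
One way to picture how such rules can be applied: try rules from highest to lowest ``prio``, slide each rule's ``pattern`` over the op list, replace the first matching window that passes ``guard`` with the op returned by ``build``, and repeat until nothing changes. The standalone sketch below implements that strategy with hypothetical classes (``ToyRule``, ``TemperatureOp`` and so on). It is an assumption about the general approach, not FlashInfer's actual fusion pass, whose matching order and tie-breaking may differ.

.. code-block:: python

    from dataclasses import dataclass
    from typing import Callable, List, Tuple, Type


    class Op:
        """Hypothetical minimal op base class (not FlashInfer's Op)."""


    class TemperatureOp(Op):
        pass


    class SoftmaxOp(Op):
        pass


    class SampleOp(Op):
        pass


    class FusedTemperatureSoftmaxOp(Op):
        pass


    class FusedSoftmaxSampleOp(Op):
        pass


    @dataclass(frozen=True)
    class ToyRule:
        pattern: Tuple[Type[Op], ...]
        guard: Callable[[List[Op]], bool]
        build: Callable[[List[Op]], Op]
        prio: int = 0


    def apply_fusion_rules(ops: List[Op], rules: List[ToyRule]) -> List[Op]:
        ops = list(ops)
        # Highest prio first; sorted() is stable, so equal prios keep list order
        ordered = sorted(rules, key=lambda r: r.prio, reverse=True)
        changed = True
        while changed:
            changed = False
            for rule in ordered:
                n = len(rule.pattern)  # assumed >= 2, so every fusion shrinks the list
                for start in range(len(ops) - n + 1):
                    window = ops[start:start + n]
                    matches = all(isinstance(op, cls) for op, cls in zip(window, rule.pattern))
                    if matches and rule.guard(window):
                        ops[start:start + n] = [rule.build(window)]
                        changed = True
                        break
                if changed:
                    break  # restart from the highest-priority rule
        return ops


    rules = [
        ToyRule((SoftmaxOp, SampleOp), lambda w: True, lambda w: FusedSoftmaxSampleOp(), prio=10),
        ToyRule((TemperatureOp, SoftmaxOp), lambda w: True, lambda w: FusedTemperatureSoftmaxOp(), prio=20),
    ]
    fused = apply_fusion_rules([TemperatureOp(), SoftmaxOp(), SampleOp()], rules)
    print([type(op).__name__ for op in fused])
    # ['FusedTemperatureSoftmaxOp', 'SampleOp']

In this sketch ``prio`` only matters when rules compete for overlapping windows: raising the softmax/sample rule above 20 yields ``[TemperatureOp, FusedSoftmaxSampleOp]`` instead.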