[V1][Core] Add guided decoding in v1
Signed-off-by: wangerxiao <[email protected]>
WangErXiao committed Jan 12, 2025
1 parent 8bddb73 commit b574a36
Showing 11 changed files with 181 additions and 16 deletions.
9 changes: 9 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -14,6 +14,14 @@
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
@@ -236,6 +244,7 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.skip_v1
def test_guided_json_object(llm, guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
39 changes: 35 additions & 4 deletions vllm/model_executor/layers/logits_processor.py
@@ -1,6 +1,6 @@
"""A layer that compute logits from hidden_stats."""
import inspect
from typing import Optional
from typing import Optional, Union

import torch
import torch.nn as nn
@@ -12,6 +12,7 @@
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.v1.sample.metadata import SamplingMetadata as SamplingMetadataV1


class LogitsProcessor(nn.Module):
@@ -51,13 +52,14 @@ def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
sampling_metadata: Optional[Union[SamplingMetadata,
SamplingMetadataV1]] = None,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
if sampling_metadata is not None:
if not envs.VLLM_USE_V1 and sampling_metadata is not None:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)

@@ -74,7 +76,12 @@

# Apply logits processors (if any).
if sampling_metadata is not None:
logits = _apply_logits_processors(logits, sampling_metadata)
if envs.VLLM_USE_V1:
logits = _apply_logits_processors_v1(
logits, sampling_metadata)
else:
logits = _apply_logits_processors(logits,
sampling_metadata)

return logits

@@ -162,3 +169,27 @@ def _apply_logits_processors(
# verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0]
return logits


def _apply_logits_processors_v1(
logits: torch.Tensor,
sampling_metadata: SamplingMetadataV1,
) -> torch.Tensor:
for idx, req_logits_processors in enumerate(
sampling_metadata.logits_processors):
if not req_logits_processors:
continue

past_tokens_ids = sampling_metadata.output_token_ids[idx]
for logits_processor in req_logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
# TODO(WangErXiao): decide whether to also provide prompt token ids
# info on the CPU
prompt_tokens_ids = sampling_metadata.prompt_token_ids[idx]
logits[idx] = logits_processor(prompt_tokens_ids,
past_tokens_ids, logits[idx])
else:
logits[idx] = logits_processor(past_tokens_ids, logits[idx])

return logits
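
A minimal sketch of the two calling conventions that the inspect.signature dispatch above distinguishes. Both processor bodies below are illustrative assumptions; the only contract the dispatch relies on is the parameter count (two arguments, or three when the processor also wants the prompt token ids).

import torch

def ban_token_processor(past_token_ids, logits):
    # 2-parameter form: receives only the previously generated token ids.
    logits[42] = float("-inf")
    return logits

def prompt_aware_processor(prompt_token_ids, past_token_ids, logits):
    # 3-parameter form: additionally receives the prompt token ids.
    if 7 in prompt_token_ids:
        logits[7] = float("-inf")
    return logits

row = torch.zeros(100)
row = ban_token_processor([1, 2, 3], row)
row = prompt_aware_processor([7, 8], [1, 2, 3], row)

In the batched v1 path, each request's row logits[idx] is passed through that request's own processor list, so processor state stays per request.
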
3 changes: 2 additions & 1 deletion vllm/v1/engine/async_llm.py
@@ -2,7 +2,7 @@
import os
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

from vllm.config import ModelConfig, VllmConfig
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
@@ -64,6 +64,7 @@ def __init__(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
decoding_config=vllm_config.decoding_config or DecodingConfig(),
tokenizer=self.tokenizer,
input_registry=input_registry,
)
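
The vllm_config.decoding_config or DecodingConfig() fallback guarantees that the Processor always has a default guided decoding backend to fall back on when a request does not name one. A tiny illustrative sketch (not specific to this commit):

from vllm.config import DecodingConfig

cfg = DecodingConfig()
# Backend used by Processor._build_logits_processors when the request's
# GuidedDecodingParams does not specify one.
print(cfg.guided_decoding_backend)
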
14 changes: 12 additions & 2 deletions vllm/v1/engine/core.py
@@ -13,6 +13,7 @@

from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
@@ -23,7 +24,7 @@
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.serial_utils import CloudPickleEncoder, PickleEncoder
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -270,17 +271,26 @@ def process_input_socket(self, input_path: str):
# Msgpack serialization decoding.
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()
cloud_decoder_add_req = CloudPickleEncoder()

with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
frames = socket.recv_multipart(copy=False)
type_frame, data_frame = frames[0], frames[1]
request_type = type_frame.buffer
request_data = data_frame.buffer

# Deserialize the request data.
if request_type == EngineCoreRequestType.ADD.value:
request = decoder_add_req.decode(request_data)
if len(frames) >= 3:
# Use cloudpickle for logits processors
assert isinstance(request.sampling_params,
SamplingParams)
lprocs = cloud_decoder_add_req.decode(frames[2].buffer)
request.sampling_params.logits_processors = lprocs

elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
elif request_type == EngineCoreRequestType.PROFILE.value:
33 changes: 28 additions & 5 deletions vllm/v1/engine/core_client.py
@@ -2,7 +2,7 @@
import signal
import weakref
from abc import ABC, abstractmethod
from typing import List, Type
from typing import Any, List, Tuple, Type, Union

import msgspec
import zmq
@@ -17,7 +17,7 @@
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.serial_utils import CloudPickleEncoder, PickleEncoder
from vllm.v1.utils import BackgroundProcHandle

logger = init_logger(__name__)
@@ -151,6 +151,7 @@ def sigusr1_handler(signum, frame):

# Serialization setup.
self.encoder = PickleEncoder()
self.cloudpickle_encoder = CloudPickleEncoder()
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

# ZMQ setup.
@@ -183,6 +184,29 @@ def sigusr1_handler(signum, frame):
"log_stats": log_stats,
})

def _prepare_request(
self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> Tuple[Union[bytes, Any], ...]:
# TODO(WangErXiao)
# Detach logits processors so that they can be pickled
# separately (may require cloudpickle, which is slower).
if request_type == EngineCoreRequestType.ADD:
assert isinstance(request, EngineCoreRequest)
if request.sampling_params and \
request.sampling_params.logits_processors:
logits_processors = request.sampling_params.logits_processors
request.sampling_params.logits_processors = None
lp_bytes = self.cloudpickle_encoder.encode(logits_processors)
else:
lp_bytes = None

# (RequestType, SerializedRequest) or
# (RequestType, SerializedRequest, LogitsProcessors)
msg = (request_type.value, self.encoder.encode(request), lp_bytes) \
if lp_bytes else \
(request_type.value, self.encoder.encode(request))
return msg

def shutdown(self):
"""Clean up background resources."""
if hasattr(self, "proc_handle"):
@@ -214,8 +238,7 @@ def get_output(self) -> List[EngineCoreOutput]:
def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
msg = self._prepare_request(request_type, request)
self.input_socket.send_multipart(msg, copy=False)

def add_request(self, request: EngineCoreRequest) -> None:
@@ -256,7 +279,7 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
async def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

msg = (request_type.value, self.encoder.encode(request))
msg = self._prepare_request(request_type, request)
await self.input_socket.send_multipart(msg, copy=False)

async def add_request_async(self, request: EngineCoreRequest) -> None:
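
A minimal end-to-end sketch of the framing convention shared by _prepare_request on the client and process_input_socket in the engine core, assuming a local pyzmq PUSH/PULL pair and a plain dict standing in for EngineCoreRequest; the third frame is present only when logits processors were detached and cloudpickled.

import pickle

import cloudpickle
import zmq

ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)
pull.bind("inproc://engine_core_demo")
push = ctx.socket(zmq.PUSH)
push.connect("inproc://engine_core_demo")

# Client side: (request_type, pickled_request[, cloudpickled_logits_processors]).
request = {"request_id": "req-0"}           # stand-in for EngineCoreRequest
lprocs = [lambda past_ids, logits: logits]  # closures/lambdas need cloudpickle
push.send_multipart([b"add", pickle.dumps(request), cloudpickle.dumps(lprocs)])

# Engine-core side: tolerate both the 2-frame and 3-frame layouts.
frames = pull.recv_multipart(copy=False)
decoded = pickle.loads(frames[1].buffer)
if len(frames) >= 3:
    decoded["logits_processors"] = cloudpickle.loads(frames[2].buffer)
print(sorted(decoded))

Cloudpickling the processors as a separate frame keeps the common case (no logits processors) on the faster plain-pickle path.
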
4 changes: 3 additions & 1 deletion vllm/v1/engine/llm_engine.py
@@ -2,7 +2,7 @@

from typing_extensions import TypeVar

from vllm.config import VllmConfig
from vllm.config import DecodingConfig, VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
@@ -56,6 +56,8 @@ def __init__(
self.processor = Processor(model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
decoding_config=vllm_config.decoding_config
or DecodingConfig(),
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
59 changes: 58 additions & 1 deletion vllm/v1/engine/processor.py
@@ -1,12 +1,15 @@
import time
from typing import Mapping, Optional, Union

from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.config import CacheConfig, DecodingConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
MultiModalKwargs, MultiModalRegistry)
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
@@ -17,6 +20,8 @@
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient

logger = init_logger(__name__)


class Processor:

@@ -25,13 +30,15 @@ def __init__(
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
decoding_config: Optional[DecodingConfig],
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):

self.model_config = model_config
self.lora_config = lora_config
self.decoding_config = decoding_config or DecodingConfig()
self.tokenizer = tokenizer

self.generation_config_fields = model_config.try_get_generation_config(
@@ -103,6 +110,9 @@ def process_inputs(
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)

# Add logits processors to sampling params
self._build_logits_processors(sampling_params, lora_request)

# Multimodal related.
# Compute MM hashes (if enabled)
mm_hashes = None
@@ -221,3 +231,50 @@ def _validate_model_inputs(self, inputs: ProcessorInputs):
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens

def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""

logits_processors = []

if sampling_params.guided_decoding is not None:
# TODO(WangErXiao) Whether to copy sampling_params or not in v1
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
# sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug(
"Building guided decoding logits processor in "
"Processor. Params: %s", guided_decoding)

# tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=self._get_lora_tokenizer(lora_request),
model_config=self.model_config)
if processor:
logits_processors.append(processor)

# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None

if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
else:
sampling_params.logits_processors.extend(logits_processors)

def _get_lora_tokenizer(self, lora_request: Optional[LoRARequest] = None):
# TODO(WangErXiao) add lora tokenizer
# self.tokenizer is BaseTokenizerGroup type, so we need to return
# original tokenizer for logits processor
return self.tokenizer.get_lora_tokenizer(lora_request)
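
A usage-level sketch of the path this enables, using the public GuidedDecodingParams API also exercised by the test file above; the model name is illustrative and nothing below depends on Processor internals.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Processor._build_logits_processors converts guided_decoding into a logits
# processor before the request is handed to the engine core.
params = SamplingParams(
    temperature=0.8,
    max_tokens=32,
    guided_decoding=GuidedDecodingParams(regex=r"\d{3}-\d{4}"),
)

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # any supported model works here
outputs = llm.generate(["My phone number is "], params)
print(outputs[0].outputs[0].text)
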
5 changes: 5 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -3,6 +3,8 @@

import torch

from vllm.logits_process import LogitsProcessor


@dataclass
class SamplingMetadata:
@@ -29,3 +31,6 @@ class SamplingMetadata:
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]

# guided decoding
logits_processors: List[Optional[List[LogitsProcessor]]]
11 changes: 11 additions & 0 deletions vllm/v1/serial_utils.py
@@ -1,5 +1,7 @@
import pickle

import cloudpickle


class PickleEncoder:

@@ -8,3 +10,12 @@ def encode(self, obj):

def decode(self, data):
return pickle.loads(data)


class CloudPickleEncoder:

def encode(self, obj):
return cloudpickle.dumps(obj)

def decode(self, data):
return cloudpickle.loads(data)
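
A minimal check of why CloudPickleEncoder exists: the standard pickle module serializes functions by reference and rejects closures and lambdas (typical shapes for user-supplied logits processors), while cloudpickle serializes them by value. The helper below is a hypothetical example, not part of vLLM.

import pickle

import cloudpickle

def make_ban_processor(banned_id: int):
    def processor(past_token_ids, logits):
        logits[banned_id] = float("-inf")
        return logits
    return processor

proc = make_ban_processor(42)

try:
    pickle.dumps(proc)  # closures cannot be pickled by reference
except (pickle.PicklingError, AttributeError) as exc:
    print(f"pickle failed as expected: {exc}")

restored = cloudpickle.loads(cloudpickle.dumps(proc))  # round-trips by value
logits = [0.0] * 64
print(restored([], logits)[42])  # -inf
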