-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from AlmogBaku/feat/struct_handling
feat: #1 add support for complex struct parsing with streaming
- Loading branch information
Showing 15 changed files with 1,275 additions and 49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
name: "Pull Request Triage" | ||
on: | ||
# NB: using `pull_request_target` runs this in the context of | ||
# the base repository, so it has permission to upload to the checks API. | ||
# This means changes to this file won't take effect until they are merged | ||
# into the main branch. | ||
pull_request_target: | ||
types: [ opened, edited, reopened, synchronize ] | ||
|
||
permissions: | ||
contents: read | ||
pull-requests: write | ||
issues: write | ||
|
||
jobs: | ||
triage: | ||
name: "Triage Pull Request" | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: codelytv/pr-size-labeler@v1 | ||
with: | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
xs_label: 'size/xs' | ||
xs_max_size: '15' | ||
s_label: 'size/s' | ||
s_max_size: '100' | ||
m_label: 'size/m' | ||
m_max_size: '500' | ||
l_label: 'size/l' | ||
l_max_size: '1000' | ||
xl_label: 'size/xl' | ||
fail_if_xl: 'false' | ||
message_if_xl: > | ||
This PR exceeds the recommended size of 1000 lines. | ||
Please make sure you are NOT addressing multiple issues with one PR. | ||
Note this PR might be rejected due to its size. | ||
files_to_ignore: '' | ||
# - name: "Check for PR body length" | ||
# shell: bash | ||
# env: | ||
# PR_BODY: ${{ github.event.pull_request.body }} | ||
# run: | | ||
# if [ ${#PR_BODY} -lt 80 ]; then | ||
# echo "::error title=PR body is too short::Your PR probably isn't descriptive enough.\nYou should give a description that highlights both what you're doing and *why* you're doing it. Someone reading the PR description without clicking any issue links should be able to roughly understand what's going on." | ||
# exit 1 | ||
# fi | ||
- uses: amannn/action-semantic-pull-request@v5 | ||
env: | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
with: | ||
disallowScopes: | | ||
release |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ dist | |
openai_streaming.egg-info/ | ||
.benchmarks | ||
junit | ||
.venv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .handler import process_struct_response, Terminate, BaseHandler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from typing import Protocol, Literal, AsyncGenerator, Optional, Type, TypeVar, Union, Dict, Any, Tuple | ||
|
||
from pydantic import BaseModel | ||
|
||
from json_streamer import Parser, JsonParser | ||
from .yaml_parser import YamlParser | ||
from ..stream_processing import OAIResponse, process_response | ||
|
||
# Type variable standing for the Pydantic model class a handler works with.
TModel = TypeVar("TModel", bound=BaseModel)
|
||
|
||
class Terminate:
    """
    Marker type: a handler returns an instance of it to request that parsing stops.
    """
|
||
|
||
class BaseHandler(Protocol[TModel]):
    """
    Structural interface for consumers of a streamed, structured response from OpenAI.
    """

    def model(self) -> Type[TModel]:
        """
        Tell the parser which Pydantic data model the streamed data is loaded into.
        :return: type of the Pydantic model
        """
        ...

    async def handle_partially_parsed(self, data: TModel) -> Optional[Terminate]:
        """
        Receive the model built from the data that has been parsed so far.
        :param data: The partially parsed object
        :return: `None` to keep going, or a `Terminate` instance to stop the parsing
        """
        ...

    async def terminated(self):
        """
        Hook that is called once the parsing was terminated.
        """
        ...
|
||
|
||
# Serialization formats in which the structured response may be written.
OutputSerialization = Literal["json", "yaml"]
|
||
|
||
class _ContentHandler:
    """
    Feeds streamed content tokens into an incremental parser and hands every part that
    validates against the handler's model over to the handler.
    """

    # Parser instance selected in `__init__` according to the requested serialization.
    parser: Optional[Parser] = None
    # Result recorded by `handle_content` when a run ends with a truthy result.
    _last_resp: Optional[Union[TModel, Terminate]] = None

    def __init__(self, handler: BaseHandler, output_serialization: OutputSerialization = "yaml"):
        """
        :param handler: The handler that provides the model and receives the parsed objects
        :param output_serialization: Serialization of the streamed content, "json" or "yaml"
        :raises ValueError: if `output_serialization` is neither "json" nor "yaml"
        """
        self.handler = handler
        if output_serialization == "json":
            self.parser = JsonParser()
        elif output_serialization == "yaml":
            self.parser = YamlParser()
        else:
            # Fail early: leaving `parser` as None would only surface later, when
            # `handle_content` tries to call it.
            raise ValueError(
                f"Unsupported output serialization: {output_serialization!r} (expected 'json' or 'yaml')"
            )

    async def handle_content(self, content: AsyncGenerator[str, None]):
        """
        Handle the content of the response from OpenAI.

        Every token is sent to the streaming parser; each part the parser yields is passed
        to `_handle_parsed`. Processing stops as soon as a `Terminate` instance comes back.
        :param content: A generator that yields the content of the response from OpenAI
        :return: None
        """

        loader = self.parser()  # create a Streaming loader
        next(loader)  # advance the generator once so that it can accept `send`

        last_resp = None

        async for token in content:
            parsed = loader.send(token)  # send the token to the streaming loader
            while parsed:  # loop through the parsed parts as the loader yields them
                last_resp = await self._handle_parsed(parsed[1])  # handle the parsed dict of the response
                if isinstance(last_resp, Terminate):
                    break
                try:
                    parsed = next(loader)
                except StopIteration:
                    break
            if isinstance(last_resp, Terminate):
                # leave the token loop as well, not only the inner loop
                break

        if not last_resp:
            # the final part was not usable (or nothing was parsed): record nothing
            return
        if isinstance(last_resp, Terminate):
            await self.handler.terminated()

        self._last_resp = last_resp

    async def _handle_parsed(self, part) -> Optional[Union[TModel, Terminate]]:
        """
        Handle a parsed part of the response from OpenAI.
        It parses the "parsed dictionary" as a type of `TModel` object and processes it with the handler.
        :param part: A dictionary containing the parsed part of the response
        :return: The parsed part of the response as a `TModel` object, the handler's return value when it
        is truthy (e.g. `Terminate` to terminate the handling), or `None` if the part is not valid
        """
        try:
            parsed = self.handler.model()(**part)
        except (TypeError, ValueError):
            # The part cannot be turned into the model (yet): skip it.
            return

        ret = await self.handler.handle_partially_parsed(parsed)
        return ret if ret else parsed

    def get_last_response(self) -> Optional[Union[TModel, Terminate]]:
        """
        Get the last response from OpenAI.
        :return: The result recorded by `handle_content`, or `None` if nothing was recorded
        """
        return self._last_resp
|
||
|
||
async def process_struct_response(
        response: OAIResponse,
        handler: BaseHandler,
        output_serialization: OutputSerialization = "json"
) -> Tuple[Optional[Union[TModel, Terminate]], Dict[str, Any]]:
    """
    Parse a structured OpenAI response while it is being streamed.
    Typical use: the response holds both reasoning and content, but only the content should be
    streamed on to the user.
    :param response: The response from OpenAI
    :param handler: The handler for the response. It should implement `BaseHandler`
    :param output_serialization: The output serialization of the response. It should be either "json" or "yaml"
    :return: A tuple of the last parsed response, and the second value returned by `process_response`
    :raises ValueError: if no response was recorded while processing the content
    """
    # Wrap the user's handler; the wrapper's `handle_content` is what consumes the content stream.
    content_handler = _ContentHandler(handler, output_serialization)
    _, result = await process_response(response, content_handler.handle_content, self=content_handler)

    last_response = content_handler.get_last_response()
    if not last_response:
        raise ValueError("Probably invalid response from OpenAI")

    return last_response, result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import List, Dict, Tuple, Generator, Optional | ||
from json_streamer import Parser, ParseState | ||
|
||
|
||
class YamlParser(Parser):
    """
    Parse partial YAML
    """

    @staticmethod
    def opening_symbols() -> List[str]:
        """
        :return: the single-character symbols '{', '[' and '"'
        """
        return ['{', '[', '"']

    def raw_decode(self, s: str) -> Tuple[Dict, int]:
        """
        Decode the given text with PyYAML's `safe_load`.
        :param s: the text to decode
        :return: a tuple of the loaded value and -1
        :raises ImportError: if PyYAML is not installed
        """
        try:
            # Imported lazily so PyYAML is only required when the YAML parser is actually used.
            from yaml import safe_load
        except ImportError as e:
            raise ImportError("You must install PyYAML to use the YamlParser: pip install PyYAML") from e
        # NOTE(review): -1 presumably tells the base parser that no end index is known - confirm against `Parser`.
        return safe_load(s), -1

    def parse_part(self, part: str) -> Generator[Tuple[ParseState, dict], None, None]:
        """
        Re-yield what the base parser produces for `part`, with the state replaced by `ParseState.UNKNOWN`.
        :param part: the text to parse
        :return: a generator of `(ParseState.UNKNOWN, parsed value)` tuples
        """
        for y in super().parse_part(part):
            yield ParseState.UNKNOWN, y[1]
|
||
|
||
def loads(s: Optional[Generator[str, None, None]] = None) -> Generator[Tuple[ParseState, dict], Optional[str], None]:
    """
    Create a `YamlParser` and call it with `s`.
    :param s: optional generator of text that is passed on to the parser
    :return: whatever calling the parser returns; annotated as a generator of `(ParseState, dict)` tuples
    """
    return YamlParser()(s)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# build with: pipx run pydoc-markdown | ||
loaders: | ||
- type: python | ||
search_path: | ||
|
Oops, something went wrong.