diff --git a/go.mod b/go.mod index b94e8dfc..516cd3db 100644 --- a/go.mod +++ b/go.mod @@ -10,12 +10,14 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.28.7 github.com/envoyproxy/gateway v1.2.4 github.com/envoyproxy/go-control-plane/envoy v1.32.2 + github.com/google/go-cmp v0.6.0 github.com/openai/openai-go v0.1.0-alpha.43 github.com/stretchr/testify v1.10.0 golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.36.1 k8s.io/apimachinery v0.32.0 + k8s.io/utils v0.0.0-20241104163129-6fe5fd82f078 sigs.k8s.io/controller-runtime v0.19.3 sigs.k8s.io/gateway-api v1.2.1 sigs.k8s.io/yaml v1.4.0 @@ -51,7 +53,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect - github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/imdario/mergo v1.0.1 // indirect @@ -92,7 +93,6 @@ require ( k8s.io/client-go v0.31.2 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f // indirect - k8s.io/utils v0.0.0-20241104163129-6fe5fd82f078 // indirect sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.3 // indirect ) diff --git a/internal/apischema/awsbedrock/awsbedrock.go b/internal/apischema/awsbedrock/awsbedrock.go index b868a03f..25de5ba2 100644 --- a/internal/apischema/awsbedrock/awsbedrock.go +++ b/internal/apischema/awsbedrock/awsbedrock.go @@ -5,44 +5,362 @@ import ( "strings" ) -// ConverseRequest is defined in the AWS Bedrock API: -// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestBody -type ConverseRequest struct { - Messages []Message `json:"messages,omitempty"` +const ( + // StopReasonEndTurn is a StopReason enum value. + StopReasonEndTurn = "end_turn" + + // StopReasonToolUse is a StopReason enum value. + StopReasonToolUse = "tool_use" + + // StopReasonMaxTokens is a StopReason enum value. + StopReasonMaxTokens = "max_tokens" + + // StopReasonStopSequence is a StopReason enum value. + StopReasonStopSequence = "stop_sequence" + + // StopReasonGuardrailIntervened is a StopReason enum value. + StopReasonGuardrailIntervened = "guardrail_intervened" + + // StopReasonContentFiltered is a StopReason enum value. + StopReasonContentFiltered = "content_filtered" + + // ConversationRoleUser is a ConversationRole enum value. + ConversationRoleUser = "user" + + // ConversationRoleAssistant is a ConversationRole enum value. + ConversationRoleAssistant = "assistant" +) + +// InferenceConfiguration Base inference parameters to pass to a model in a call to Converse (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html) +// or ConverseStream (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html). +// For more information, see Inference parameters for foundation models (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html). +// +// If you need to pass additional parameters that the model supports, use the +// additionalModelRequestFields request field in the call to Converse or ConverseStream. +// For more information, see Model parameters (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html). 
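+//
+// For example, the JSON form of this struct inside a Converse request looks like
+// {"maxTokens": 256, "temperature": 0.7, "topP": 0.9}.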
+type InferenceConfiguration struct { + // The maximum number of tokens to allow in the generated response. The default + // value is the maximum allowed value for the model that you are using. For + // more information, see Inference parameters for foundation models (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html). + MaxTokens *int64 `json:"maxTokens,omitempty"` + + // A list of stop sequences. A stop sequence is a sequence of characters that + // causes the model to stop generating the response. + StopSequences []*string `json:"stopSequences,omitempty"` + + // The likelihood of the model selecting higher-probability options while generating + // a response. A lower value makes the model more likely to choose higher-probability + // options, while a higher value makes the model more likely to choose lower-probability + // options. + // + // The default value is the default value for the model that you are using. + // For more information, see Inference parameters for foundation models (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html). + Temperature *float64 `json:"temperature,omitempty"` + + // The percentage of most-likely candidates that the model considers for the + // next token. For example, if you choose a value of 0.8 for topP, the model + // selects from the top 80% of the probability distribution of tokens that could + // be next in the sequence. + // + // The default value is the default value for the model that you are using. + // For more information, see Inference parameters for foundation models (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html). + TopP *float64 `json:"topP,omitempty"` +} + +// GuardrailConverseTextBlock A text block that contains text that you want to assess with a guardrail. +// For more information, see GuardrailConverseContentBlock. +type GuardrailConverseTextBlock struct { + // The qualifier details for the guardrails contextual grounding filter. + Qualifiers []*string `json:"qualifiers,omitempty"` + + // The text that you want to guard. + // + // Text is a required field + Text *string `json:"text"` +} + +// GuardrailConverseContentBlock A content block for selective guarding with the Converse (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html) +// or ConverseStream (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html) +// API operations. +type GuardrailConverseContentBlock struct { + // The text to guard. + Text *GuardrailConverseTextBlock `json:"text"` } -// Message is defined in the AWS Bedrock API: -// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Message.html#bedrock-Type-runtime_Message-content +// SystemContentBlock A system content block. +type SystemContentBlock struct { + // A content block to assess with the guardrail. Use with the Converse (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html) + // or ConverseStream (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html) + // API operations. + // + // For more information, see Use a guardrail with the Converse API in the Amazon + // Bedrock User Guide. + GuardContent *GuardrailConverseContentBlock `json:"guardContent,omitempty"` + + // A system prompt for the model. 
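+	// For example, "You are a helpful assistant."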
+	Text string `json:"text"`
+}
+
+// GuardrailConfiguration Configuration information for a guardrail that you use with the Converse
+// (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html)
+// operation.
+type GuardrailConfiguration struct {
+	// The identifier for the guardrail.
+	//
+	// GuardrailIdentifier is a required field
+	GuardrailIdentifier *string `json:"guardrailIdentifier"`
+
+	// The version of the guardrail.
+	//
+	// GuardrailVersion is a required field
+	GuardrailVersion *string `json:"guardrailVersion"`
+
+	// The trace behavior for the guardrail.
+	Trace *string `json:"trace,omitempty"`
+}
+
+// ConverseInput is the request body for the Converse API:
+// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestBody
+type ConverseInput struct {
+	// Additional model parameters field paths to return in the response. Converse
+	// returns the requested fields as a JSON Pointer object in the additionalModelResponseFields
+	// field. The following is example JSON for additionalModelResponseFieldPaths.
+	//
+	// [ "/stop_sequence" ]
+	//
+	// For information about the JSON Pointer syntax, see the Internet Engineering
+	// Task Force (IETF) (https://datatracker.ietf.org/doc/html/rfc6901) documentation.
+	//
+	// Converse rejects an empty JSON Pointer or incorrectly structured JSON Pointer
+	// with a 400 error code. If the JSON Pointer is valid, but the requested field
+	// is not in the model response, it is ignored by Converse.
+	AdditionalModelResponseFieldPaths []*string `json:"additionalModelResponseFieldPaths,omitempty"`
+
+	// Configuration information for a guardrail that you want to use in the request.
+	GuardrailConfig *GuardrailConfiguration `json:"guardrailConfig,omitempty"`
+
+	// Inference parameters to pass to the model. Converse supports a base set of
+	// inference parameters. If you need to pass additional parameters that the
+	// model supports, use the additionalModelRequestFields request field.
+	InferenceConfig *InferenceConfiguration `json:"inferenceConfig,omitempty"`
+
+	// The messages that you want to send to the model.
+	//
+	// Messages is a required field
+	Messages []*Message `json:"messages"`
+
+	// The identifier for the model that you want to call.
+	//
+	// The modelId to provide depends on the type of model that you use:
+	//
+	// * If you use a base model, specify the model ID or its ARN. For a list
+	// of model IDs for base models, see Amazon Bedrock base model IDs (on-demand
+	// throughput) (https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns)
+	// in the Amazon Bedrock User Guide.
+	//
+	// * If you use a provisioned model, specify the ARN of the Provisioned Throughput.
+	// For more information, see Run inference using a Provisioned Throughput
+	// (https://docs.aws.amazon.com/bedrock/latest/userguide/prov-thru-use.html)
+	// in the Amazon Bedrock User Guide.
+	//
+	// * If you use a custom model, first purchase Provisioned Throughput for
+	// it. Then specify the ARN of the resulting provisioned model. For more
+	// information, see Use a custom model in Amazon Bedrock (https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-use.html)
+	// in the Amazon Bedrock User Guide.
+	//
+	// ModelId is a required field
+	ModelID *string `json:"modelId"`
+
+	// A system prompt to pass to the model.
+	System []*SystemContentBlock `json:"system,omitempty"`
+
+	// Configuration information for the tools that the model can use when generating
+	// a response.
+	//
+	// This field is only supported by Anthropic Claude 3, Cohere Command R, Cohere
+	// Command R+, and Mistral Large models.
+	ToolConfig *ToolConfiguration `json:"toolConfig,omitempty"`
+}
+
+// Message A message input to, or returned from, a call to Converse (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html)
+// or ConverseStream (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html).
 type Message struct {
-	Role    string         `json:"role,omitempty"`
-	Content []ContentBlock `json:"content,omitempty"`
+	// The message content. Note the following restrictions:
+	//
+	// * You can include up to 20 images. Each image's size, height, and width
+	// must be no more than 3.75 MB, 8000 px, and 8000 px, respectively.
+	//
+	// * You can include up to five documents. Each document's size must be no
+	// more than 4.5 MB.
+	//
+	// * If you include a ContentBlock with a document field in the array, you
+	// must also include a ContentBlock with a text field.
+	//
+	// * You can only include images and documents if the role is user.
+	//
+	// Content is a required field
+	Content []*ContentBlock `json:"content"`
+
+	// The role that the message plays in the conversation.
+	//
+	// Role is a required field
+	Role string `json:"role"`
+}
+
+// ImageSource The source for an image.
+type ImageSource struct {
+	// The raw image bytes for the image. If you use an AWS SDK, you don't need
+	// to encode the image bytes in base64.
+	// Bytes are automatically base64 encoded/decoded by the SDK.
+	Bytes []byte `json:"bytes"`
+}
+
+// ImageBlock Image content for a message.
+type ImageBlock struct {
+	// The format of the image.
+	//
+	// Format is a required field
+	Format string `json:"format"`
+
+	// The source for the image.
+	//
+	// Source is a required field
+	Source ImageSource `json:"source"`
+}
+
+// DocumentSource Contains the content of a document.
+type DocumentSource struct {
+	// The raw bytes for the document. If you use an Amazon Web Services SDK, you
+	// don't need to encode the bytes in base64.
+	// Bytes are automatically base64 encoded/decoded by the SDK.
+	Bytes []byte `json:"bytes"`
+}
+
+// DocumentBlock A document to include in a message.
+type DocumentBlock struct {
+	// The format of a document, or its extension.
+	//
+	// Format is a required field
+	Format string `json:"format"`
+
+	// A name for the document. The name can only contain the following characters:
+	//
+	// * Alphanumeric characters
+	//
+	// * Whitespace characters (no more than one in a row)
+	//
+	// * Hyphens
+	//
+	// * Parentheses
+	//
+	// * Square brackets
+	//
+	// This field is vulnerable to prompt injections, because the model might inadvertently
+	// interpret it as instructions. Therefore, we recommend that you specify a
+	// neutral name.
+	//
+	// Name is a required field
+	Name string `json:"name"`
+
+	// Contains the content of the document.
+	//
+	// Source is a required field
+	Source DocumentSource `json:"source"`
+}
+
+// ToolResultContentBlock The tool result content block.
+type ToolResultContentBlock struct {
+	// A tool result that is a document.
+	Document *DocumentBlock `json:"document,omitempty"`
+
+	// A tool result that is an image.
+	//
+	// This field is only supported by Anthropic Claude 3 models.
+	Image *ImageBlock `json:"image,omitempty"`
+
+	// A tool result that is text.
+	Text *string `json:"text,omitempty"`
+}
+
+// ToolResultBlock A tool result block that contains the results for a tool request that the
+// model previously made.
+type ToolResultBlock struct {
+	// The content for the tool result content block.
+	//
+	// Content is a required field
+	Content []*ToolResultContentBlock `json:"content"`
+
+	// The status for the tool result content block.
+	//
+	// This field is only supported by Anthropic Claude 3 models.
+	Status *string `json:"status"`
+
+	// The ID of the tool request that this is the result for.
+	//
+	// ToolUseId is a required field
+	ToolUseID *string `json:"toolUseId"`
 }
 
 // ContentBlock is defined in the AWS Bedrock API:
 // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html
 type ContentBlock struct {
-	Text string `json:"text,omitempty"`
+	// A document to include in the message.
+	Document *DocumentBlock `json:"document,omitempty"`
+
+	// An image to include in the message.
+	//
+	// This field is only supported by Anthropic Claude 3 models.
+	Image *ImageBlock `json:"image,omitempty"`
+
+	// Text to include in the message.
+	Text *string `json:"text,omitempty"`
+
+	// The result for a tool request that a model makes.
+	ToolResult *ToolResultBlock `json:"toolResult,omitempty"`
+}
+
+// ConverseMetrics Metrics for a call to Converse (https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html).
+type ConverseMetrics struct {
+	// The latency of the call to Converse, in milliseconds.
+	//
+	// LatencyMs is a required field
+	LatencyMs *int64 `json:"latencyMs"`
 }
 
-// ConverseResponse is defined in the AWS Bedrock API:
-// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseElements
-type ConverseResponse struct {
-	Output ConverseResponseOutput `json:"output,omitempty"`
-	Usage  TokenUsage             `json:"usage,omitempty"`
+// ConverseOutput is the response body from the Converse API:
+// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseElements
+type ConverseOutput struct {
+	// Metrics for the call to Converse.
+	//
+	// Metrics is a required field
+	Metrics *ConverseMetrics `json:"metrics"`
+
+	// The result from the call to Converse.
+	//
+	// Output is a required field
+	Output *ConverseOutput_ `json:"output"`
+
+	// The reason why the model stopped generating output.
+	//
+	// StopReason is a required field
+	//
+	// Valid Values: end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
+	StopReason *string `json:"stopReason"`
+
+	// The total number of tokens used in the call to Converse. The total includes
+	// the tokens input to the model and the tokens generated by the model.
+	//
+	// Usage is a required field
+	Usage *TokenUsage `json:"usage"`
 }
 
-// ConverseResponseOutput is defined in the AWS Bedrock API:
+// ConverseOutput_ is defined in the AWS Bedrock API:
 // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseOutput.html
-type ConverseResponseOutput struct {
+type ConverseOutput_ struct {
 	Message Message `json:"message,omitempty"`
 }
 
 // TokenUsage is defined in the AWS Bedrock API:
 // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_TokenUsage.html
 type TokenUsage struct {
-	InputTokens  int `json:"inputTokens,omitempty"`
-	OutputTokens int `json:"outputTokens,omitempty"`
-	TotalTokens  int `json:"totalTokens,omitempty"`
+	InputTokens  int `json:"inputTokens"`
+	OutputTokens int `json:"outputTokens"`
+	TotalTokens  int `json:"totalTokens"`
 }
 
 // ConverseStreamEvent is the union of all possible event types in the AWS Bedrock API:
@@ -66,3 +384,87 @@ func (c ConverseStreamEvent) String() string {
 type ConverseStreamEventContentBlockDelta struct {
 	Text string `json:"text,omitempty"`
 }
+
+// BedrockException represents an error response from the Bedrock API.
+type BedrockException struct {
+	Message string `json:"message"`
+}
+
+// AnyToolChoice The model must request at least one tool (no text is generated). For example,
+// {"any" : {}}.
+type AnyToolChoice struct{}
+
+// AutoToolChoice The model automatically decides if a tool should be called or whether to
+// generate text instead. For example, {"auto" : {}}.
+type AutoToolChoice struct{}
+
+// SpecificToolChoice The model must request a specific tool. For example, {"tool" : {"name" :
+// "Your tool name"}}.
+//
+// This field is only supported by Anthropic Claude 3 models.
+type SpecificToolChoice struct {
+	// The name of the tool that the model must request.
+	//
+	// Name is a required field
+	Name *string `json:"name"`
+}
+
+// ToolChoice Determines which tools the model should request in a call to Converse or
+// ConverseStream. ToolChoice is only supported by Anthropic Claude 3 models
+// and by Mistral AI Mistral Large.
+type ToolChoice struct {
+	// The model must request at least one tool (no text is generated).
+	Any *AnyToolChoice `json:"any,omitempty"`
+
+	// (Default). The model automatically decides if a tool should be called or
+	// whether to generate text instead.
+	Auto *AutoToolChoice `json:"auto,omitempty"`
+
+	// The model must request the specified tool. Only supported by Anthropic Claude
+	// 3 models.
+	Tool *SpecificToolChoice `json:"tool,omitempty"`
+}
+
+// ToolConfiguration Configuration information for the tools that you pass to a model. For more
+// information, see Tool use (function calling) (https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html)
+// in the Amazon Bedrock User Guide.
+//
+// This field is only supported by Anthropic Claude 3, Cohere Command R, Cohere
+// Command R+, and Mistral Large models.
+type ToolConfiguration struct {
+	// If supported by the model, forces the model to request a tool.
+	ToolChoice *ToolChoice `json:"toolChoice,omitempty"`
+
+	// An array of tools that you want to pass to a model.
+	//
+	// Tools is a required field
+	Tools []*Tool `json:"tools"`
+}
+
+// Tool Information about a tool that you can use with the Converse API. For more
+// information, see Tool use (function calling) (https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html)
+// in the Amazon Bedrock User Guide.
+type Tool struct {
+	// The specification for the tool.
+	ToolSpec *ToolSpecification `json:"toolSpec"`
+}
+
+// ToolInputSchema The schema for the tool. The top level schema type must be an object.
+type ToolInputSchema struct {
+	JSON any `json:"json"`
+}
+
+// ToolSpecification The specification for the tool.
+type ToolSpecification struct {
+	// The description for the tool.
+	Description *string `json:"description,omitempty"`
+
+	// The input schema for the tool in JSON format.
+	//
+	// InputSchema is a required field
+	InputSchema *ToolInputSchema `json:"inputSchema"`
+
+	// The name for the tool.
+	//
+	// Name is a required field
+	Name *string `json:"name"`
+}
diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go
index b1fd00d2..49893fd9 100644
--- a/internal/apischema/openai/openai.go
+++ b/internal/apischema/openai/openai.go
@@ -1,36 +1,534 @@
-// Package openai contains the following are the OpenAI API schema definitions.
+// Package openai contains the OpenAI API schema definitions.
 // Note that we intentionally do not use the code generation tools like OpenAPI Generator not only to keep the code simple
 // but also because the OpenAI's OpenAPI definition is not compliant with the spec and the existing tools do not work well.
 package openai
 
 import (
 	"encoding/json"
+	"fmt"
 	"strings"
 )
 
-// ChatCompletionRequest represents a request to /v1/chat/completions.
-// https://platform.openai.com/docs/api-reference/chat/create
+// Chat message role defined by the OpenAI API.
+const (
+	ChatMessageRoleSystem    = "system"
+	ChatMessageRoleUser      = "user"
+	ChatMessageRoleAssistant = "assistant"
+	ChatMessageRoleFunction  = "function"
+	ChatMessageRoleTool      = "tool"
+)
+
+// ChatCompletionContentPartRefusalType The type of the content part.
+type ChatCompletionContentPartRefusalType string
+
+// ChatCompletionContentPartInputAudioType The type of the content part. Always `input_audio`.
+type ChatCompletionContentPartInputAudioType string
+
+// ChatCompletionContentPartTextType The type of the content part.
+type ChatCompletionContentPartTextType string
+
+// ChatCompletionContentPartImageType The type of the content part.
+type ChatCompletionContentPartImageType string
+
+const (
+	ChatCompletionContentPartTextTypeText             ChatCompletionContentPartTextType       = "text"
+	ChatCompletionContentPartRefusalTypeRefusal       ChatCompletionContentPartRefusalType    = "refusal"
+	ChatCompletionContentPartInputAudioTypeInputAudio ChatCompletionContentPartInputAudioType = "input_audio"
+	ChatCompletionContentPartImageTypeImageURL        ChatCompletionContentPartImageType      = "image_url"
+)
+
+// ChatCompletionContentPartTextParam Learn about
+// [text inputs](https://platform.openai.com/docs/guides/text-generation).
+type ChatCompletionContentPartTextParam struct {
+	// The text content.
+	Text string `json:"text"`
+	// The type of the content part.
+	Type string `json:"type"`
+}
+
+type ChatCompletionContentPartRefusalParam struct {
+	// The refusal message generated by the model.
+	Refusal string `json:"refusal"`
+	// The type of the content part.
+	Type ChatCompletionContentPartRefusalType `json:"type"`
+}
+
+// ChatCompletionContentPartInputAudioParam Learn about [audio inputs](https://platform.openai.com/docs/guides/audio).
+type ChatCompletionContentPartInputAudioParam struct {
+	InputAudio ChatCompletionContentPartInputAudioInputAudioParam `json:"input_audio"`
+	// The type of the content part. Always `input_audio`.
+ Type ChatCompletionContentPartInputAudioType `json:"type"` +} + +// ChatCompletionContentPartInputAudioInputAudioFormat The format of the encoded audio data. Currently supports "wav" and "mp3". +type ChatCompletionContentPartInputAudioInputAudioFormat string + +const ( + ChatCompletionContentPartInputAudioInputAudioFormatWAV ChatCompletionContentPartInputAudioInputAudioFormat = "wav" + ChatCompletionContentPartInputAudioInputAudioFormatMP3 ChatCompletionContentPartInputAudioInputAudioFormat = "mp3" +) + +type ChatCompletionContentPartInputAudioInputAudioParam struct { + // Base64 encoded audio data. + Data string `json:"data"` + // The format of the encoded audio data. Currently supports "wav" and "mp3". + Format ChatCompletionContentPartInputAudioInputAudioFormat `json:"format"` +} + +type ChatCompletionContentPartImageImageURLDetail string + +const ( + ChatCompletionContentPartImageImageURLDetailAuto ChatCompletionContentPartImageImageURLDetail = "auto" + ChatCompletionContentPartImageImageURLDetailLow ChatCompletionContentPartImageImageURLDetail = "low" + ChatCompletionContentPartImageImageURLDetailHigh ChatCompletionContentPartImageImageURLDetail = "high" +) + +type ChatCompletionContentPartImageImageURLParam struct { + // Either a URL of the image or the base64 encoded image data. + URL string `json:"url"` + // Specifies the detail level of the image. Learn more in the + // [Vision guide](https://platform.openai.com/docs/guides/vision#low-or-high-fidelity-image-understanding). + Detail ChatCompletionContentPartImageImageURLDetail `json:"detail,omitempty"` +} + +// ChatCompletionContentPartImageParam Learn about [image inputs](https://platform.openai.com/docs/guides/vision). +type ChatCompletionContentPartImageParam struct { + ImageURL ChatCompletionContentPartImageImageURLParam `json:"image_url"` + // The type of the content part. + Type ChatCompletionContentPartImageType `json:"type"` +} + +// ChatCompletionContentPartUserUnionParam Learn about +// [text inputs](https://platform.openai.com/docs/guides/text-generation). 
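+//
+// Exactly one of the fields below is set; UnmarshalJSON selects it based on the
+// "type" discriminator in the JSON payload.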
+type ChatCompletionContentPartUserUnionParam struct { + TextContent *ChatCompletionContentPartTextParam + InputAudioContent *ChatCompletionContentPartInputAudioParam + ImageContent *ChatCompletionContentPartImageParam +} + +func (c *ChatCompletionContentPartUserUnionParam) UnmarshalJSON(data []byte) error { + var chatContentPart map[string]interface{} + if err := json.Unmarshal(data, &chatContentPart); err != nil { + return err + } + var contentType string + var ok bool + if contentType, ok = chatContentPart["type"].(string); !ok { + return fmt.Errorf("chat content does not have type") + } + switch contentType { + case string(ChatCompletionContentPartTextTypeText): + var textContent ChatCompletionContentPartTextParam + if err := json.Unmarshal(data, &textContent); err != nil { + return err + } + c.TextContent = &textContent + case string(ChatCompletionContentPartInputAudioTypeInputAudio): + var audioContent ChatCompletionContentPartInputAudioParam + if err := json.Unmarshal(data, &audioContent); err != nil { + return err + } + c.InputAudioContent = &audioContent + case string(ChatCompletionContentPartImageTypeImageURL): + var imageContent ChatCompletionContentPartImageParam + if err := json.Unmarshal(data, &imageContent); err != nil { + return err + } + c.ImageContent = &imageContent + default: + return fmt.Errorf("unknown ChatCompletionContentPartUnionParam type: %v", contentType) + } + return nil +} + +type StringOrArray struct { + Value interface{} +} + +func (s *StringOrArray) UnmarshalJSON(data []byte) error { + var str string + err := json.Unmarshal(data, &str) + if err == nil { + s.Value = str + return nil + } + + var arr []ChatCompletionContentPartTextParam + err = json.Unmarshal(data, &arr) + if err == nil { + s.Value = arr + return nil + } + + return fmt.Errorf("cannot unmarshal JSON data as string or array of string") +} + +type StringOrUserRoleContentUnion struct { + Value interface{} +} + +func (s *StringOrUserRoleContentUnion) UnmarshalJSON(data []byte) error { + var str string + err := json.Unmarshal(data, &str) + if err == nil { + s.Value = str + return nil + } + + var arr []ChatCompletionContentPartUserUnionParam + err = json.Unmarshal(data, &arr) + if err == nil { + s.Value = arr + return nil + } + + return fmt.Errorf("cannot unmarshal JSON data as string or array of content parts") +} + +type ChatCompletionMessageParamUnion struct { + Value interface{} + Type string +} + +func (c *ChatCompletionMessageParamUnion) UnmarshalJSON(data []byte) error { + var chatMessage map[string]interface{} + if err := json.Unmarshal(data, &chatMessage); err != nil { + return err + } + if _, ok := chatMessage["role"]; !ok { + return fmt.Errorf("chat message does not have role") + } + var role string + var ok bool + if role, ok = chatMessage["role"].(string); !ok { + return fmt.Errorf("chat message role is not string: %s", role) + } + switch role { + case ChatMessageRoleUser: + var userMessage ChatCompletionUserMessageParam + if err := json.Unmarshal(data, &userMessage); err != nil { + return err + } + c.Value = userMessage + c.Type = ChatMessageRoleUser + case ChatMessageRoleAssistant: + var assistantMessage ChatCompletionAssistantMessageParam + if err := json.Unmarshal(data, &assistantMessage); err != nil { + return err + } + c.Value = assistantMessage + c.Type = ChatMessageRoleAssistant + case ChatMessageRoleSystem: + var systemMessage ChatCompletionSystemMessageParam + if err := json.Unmarshal(data, &systemMessage); err != nil { + return err + } + c.Value = systemMessage + c.Type = 
ChatMessageRoleSystem + case ChatMessageRoleTool: + var toolMessage ChatCompletionToolMessageParam + if err := json.Unmarshal(data, &toolMessage); err != nil { + return err + } + c.Value = toolMessage + c.Type = ChatMessageRoleTool + default: + return fmt.Errorf("unknown ChatCompletionMessageParam type: %v", role) + } + return nil +} + +// ChatCompletionUserMessageParam Messages sent by an end user, containing prompts or additional context +// information. +type ChatCompletionUserMessageParam struct { + // The contents of the user message. + Content StringOrUserRoleContentUnion `json:"content"` + // The role of the messages author, in this case `user`. + Role string `json:"role"` + // An optional name for the participant. Provides the model information to + // differentiate between participants of the same role. + Name string `json:"name,omitempty"` +} + +// ChatCompletionSystemMessageParam Developer-provided instructions that the model should follow, regardless of +// messages sent by the user. With o1 models and newer, use `developer` messages +// for this purpose instead. +type ChatCompletionSystemMessageParam struct { + // The contents of the system message. + Content StringOrArray `json:"content"` + // The role of the messages author, in this case `system`. + Role string `json:"role"` + // An optional name for the participant. Provides the model information to + // differentiate between participants of the same role. + Name string `json:"name,omitempty"` +} + +type ChatCompletionToolMessageParam struct { + // The contents of the tool message. + Content StringOrArray `json:"content"` + // The role of the messages author, in this case `tool`. + Role string `json:"role"` + // Tool call that this message is responding to. + ToolCallID string `json:"tool_call_id"` +} + +// ChatCompletionAssistantMessageParamAudio Data about a previous audio response from the model. +// [Learn more](https://platform.openai.com/docs/guides/audio). +type ChatCompletionAssistantMessageParamAudio struct { + // Unique identifier for a previous audio response from the model. + ID string `json:"id"` +} + +// ChatCompletionAssistantMessageParamContentType The type of the content part. +type ChatCompletionAssistantMessageParamContentType string + +const ( + ChatCompletionAssistantMessageParamContentTypeText ChatCompletionAssistantMessageParamContentType = "text" + ChatCompletionAssistantMessageParamContentTypeRefusal ChatCompletionAssistantMessageParamContentType = "refusal" +) + +// ChatCompletionAssistantMessageParamContent Learn about +// [text inputs](https://platform.openai.com/docs/guides/text-generation). +type ChatCompletionAssistantMessageParamContent struct { + // The type of the content part. + Type ChatCompletionAssistantMessageParamContentType `json:"type"` + // The refusal message generated by the model. + Refusal *string `json:"refusal,omitempty"` + // The text content. + Text *string `json:"text,omitempty"` +} + +// ChatCompletionAssistantMessageParam Messages sent by the model in response to user messages. +type ChatCompletionAssistantMessageParam struct { + // The role of the messages author, in this case `assistant`. + Role string `json:"role"` + // Data about a previous audio response from the model. + // [Learn more](https://platform.openai.com/docs/guides/audio). + Audio ChatCompletionAssistantMessageParamAudio `json:"audio,omitempty"` + // The contents of the assistant message. Required unless `tool_calls` or + // `function_call` is specified. 
+ Content ChatCompletionAssistantMessageParamContent `json:"content"` + // An optional name for the participant. Provides the model information to + // differentiate between participants of the same role. + Name string `json:"name,omitempty"` + // The refusal message by the assistant. + Refusal string `json:"refusal,omitempty"` + // The tool calls generated by the model, such as function calls. + ToolCalls []ChatCompletionMessageToolCallParam `json:"tool_calls,omitempty"` +} + +// ChatCompletionMessageToolCallType The type of the tool. Currently, only `function` is supported. +type ChatCompletionMessageToolCallType string + +const ( + ChatCompletionMessageToolCallTypeFunction ChatCompletionMessageToolCallType = "function" +) + +// ChatCompletionMessageToolCallFunctionParam The function that the model called. +type ChatCompletionMessageToolCallFunctionParam struct { + // The arguments to call the function with, as generated by the model in JSON + // format. Note that the model does not always generate valid JSON, and may + // hallucinate parameters not defined by your function schema. Validate the + // arguments in your code before calling your function. + Arguments string `json:"arguments"` + // The name of the function to call. + Name string `json:"name"` +} + +type ChatCompletionMessageToolCallParam struct { + // The ID of the tool call. + ID string `json:"id"` + // The function that the model called. + Function ChatCompletionMessageToolCallFunctionParam `json:"function"` + // The type of the tool. Currently, only `function` is supported. + Type ChatCompletionMessageToolCallType `json:"type"` +} + +type ChatCompletionResponseFormatType string + +const ( + ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" + ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" +) + +type ChatCompletionResponseFormat struct { + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` //nolint:tagliatelle //follow openai api +} + +type ChatCompletionResponseFormatJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` +} + +// ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - // Model is described in the OpenAI API documentation: - // https://platform.openai.com/docs/api-reference/chat/create#chat-create-model + // Messages: A list of messages comprising the conversation so far. + // Depending on the model you use, different message types (modalities) are supported, + // like text, images, and audio. 
+ // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages + Messages []ChatCompletionMessageParamUnion `json:"messages"` + + // Model: ID of the model to use + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-model Model string `json:"model"` - // Messages is described in the OpenAI API documentation: - // https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages - Messages []ChatCompletionRequestMessage `json:"messages"` + // FrequencyPenalty: Number between -2.0 and 2.0 + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-frequency_penalty + FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"` //nolint:tagliatelle //follow openai api + + // LogitBias Modify the likelihood of specified tokens appearing in the completion. + // It must be a token id string (specified by their token ID in the tokenizer), not a word string. + // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-logit_bias + LogitBias map[string]int `json:"logit_bias,omitempty"` //nolint:tagliatelle //follow openai api + + // LogProbs indicates whether to return log probabilities of the output tokens or not. + // If true, returns the log probabilities of each output token returned in the content of message. + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-logprobs + LogProbs *bool `json:"logprobs,omitempty"` - // Stream is described in the OpenAI API documentation: - // https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream + // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each + // token position, each with an associated log probability. + // logprobs must be set to true if this parameter is used. + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_logprobs + TopLogProbs *int `json:"top_logprobs,omitempty"` //nolint:tagliatelle //follow openai api + + // MaxTokens The maximum number of tokens that can be generated in the chat completion. + // This value can be used to control costs for text generated via API. + // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. + // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens + MaxTokens *int64 `json:"max_tokens,omitempty"` //nolint:tagliatelle //follow openai api + + // N: LLM Gateway does not support multiple completions. + // The only accepted value is 1. + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-n + N *int `json:"n,omitempty"` + + // PresencePenalty Positive values penalize new tokens based on whether they appear in the text so far, + // increasing the model's likelihood to talk about new topics. + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-presence_penalty + PresencePenalty *float32 `json:"presence_penalty,omitempty"` //nolint:tagliatelle //follow openai api + + // ResponseFormat is only for GPT models. + // Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` //nolint:tagliatelle //follow openai api + + // Seed: This feature is in Beta. 
If specified, our system will make a best effort to
+	// sample deterministically, such that repeated requests with the same `seed` and
+	// parameters should return the same result. Determinism is not guaranteed, and you
+	// should refer to the `system_fingerprint` response parameter to monitor changes
+	// in the backend.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-seed
+	Seed *int `json:"seed,omitempty"`
+
+	// Stop: string, array, or null. Defaults to null.
+	// Up to 4 sequences where the API will stop generating further tokens.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop
+	Stop []*string `json:"stop,omitempty"`
+
+	// Stream: If set, partial message deltas will be sent, like in ChatGPT.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream
 	Stream bool `json:"stream,omitempty"`
+
+	// StreamOptions for streaming response. Only set this when you set stream: true.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"` //nolint:tagliatelle //follow openai api
+
+	// Temperature What sampling temperature to use, between 0 and 2.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature
+	Temperature *float64 `json:"temperature,omitempty"`
+
+	// TopP An alternative to sampling with temperature, called nucleus sampling,
+	// where the model considers the results of the tokens with top_p probability mass.
+	// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p
+	TopP *float64 `json:"top_p,omitempty"` //nolint:tagliatelle //follow openai api
+
+	// Tools provide a list of tool definitions to be used by the LLM.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
+	Tools []Tool `json:"tools,omitempty"`
+
+	// ToolChoice specifies a specific tool to be used by name (given in the tool definition),
+	// or use "auto" to auto-select the most appropriate one.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+	ToolChoice any `json:"tool_choice,omitempty"` //nolint:tagliatelle //follow openai api
+
+	// ParallelToolCalls enables multiple tools to be returned by the model.
+	// Docs: https://platform.openai.com/docs/guides/function-calling/parallel-function-calling
+	ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` //nolint:tagliatelle //follow openai api
+
+	// User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+	// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-user
+	User string `json:"user,omitempty"`
 }
 
-// ChatCompletionRequestMessage represents a message in a ChatCompletionRequest.
-// https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
-type ChatCompletionRequestMessage struct {
-	// Role is the role of the message. The role of the message (whether it represents the user or the AI).
-	Role string `json:"role,omitempty"`
-	// Content is the content of the message.
-	Content any `json:"content,omitempty"`
+type StreamOptions struct {
+	// If set, an additional chunk will be streamed before the data: [DONE] message.
+ // The usage field on this chunk shows the token usage statistics for the entire request, + // and the choices field will always be an empty array. + // All other chunks will also include a usage field, but with a null value. + IncludeUsage bool `json:"include_usage,omitempty"` //nolint:tagliatelle //follow openai api +} + +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +type Tool struct { + Type ToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` +} + +type ToolChoice struct { + Type ToolType `json:"type"` + Function ToolFunction `json:"function,omitempty"` +} + +type ToolFunction struct { + Name string `json:"name"` +} + +type FunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` + // Parameters is an object describing the function. + // You can pass json.RawMessage to describe the schema, + // or you can pass in a struct which serializes to the proper JSON schema. + // The jsonschema package is provided for convenience, but you should + // consider another specialized library if you require more complex schemas. + Parameters any `json:"parameters"` +} + +// Deprecated: use FunctionDefinition instead. +type FunctionDefine = FunctionDefinition + +type TopLogProbs struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` +} + +// LogProb represents the probability information for a token. +type LogProb struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null + // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. + // In rare cases, there may be fewer than the number of requested top_logprobs returned. + TopLogProbs []TopLogProbs `json:"top_logprobs"` //nolint:tagliatelle //follow openai api +} + +// LogProbs is the top-level structure containing the log probability information. +type LogProbs struct { + // Content is a list of message content tokens with log probability information. + Content []LogProb `json:"content"` } // ChatCompletionResponse represents a response from /v1/chat/completions. @@ -49,9 +547,76 @@ type ChatCompletionResponse struct { Usage ChatCompletionResponseUsage `json:"usage,omitempty"` } +// ChatCompletionChoicesFinishReason The reason the model stopped generating tokens. This will be `stop` if the model +// hit a natural stop point or a provided stop sequence, `length` if the maximum +// number of tokens specified in the request was reached, `content_filter` if +// content was omitted due to a flag from our content filters, `tool_calls` if the +// model called a tool, or `function_call` (deprecated) if the model called a +// function. +type ChatCompletionChoicesFinishReason string + +const ( + ChatCompletionChoicesFinishReasonStop ChatCompletionChoicesFinishReason = "stop" + ChatCompletionChoicesFinishReasonLength ChatCompletionChoicesFinishReason = "length" + ChatCompletionChoicesFinishReasonToolCalls ChatCompletionChoicesFinishReason = "tool_calls" + ChatCompletionChoicesFinishReasonContentFilter ChatCompletionChoicesFinishReason = "content_filter" + ChatCompletionChoicesFinishReasonFunctionCall ChatCompletionChoicesFinishReason = "function_call" +) + +type ChatCompletionTokenLogprobTopLogprob struct { + // The token. 
+ Token string `json:"token"` + // A list of integers representing the UTF-8 bytes representation of the token. + // Useful in instances where characters are represented by multiple tokens and + // their byte representations must be combined to generate the correct text + // representation. Can be `null` if there is no bytes representation for the token. + Bytes []int64 `json:"bytes,omitempty"` + // The log probability of this token, if it is within the top 20 most likely + // tokens. Otherwise, the value `-9999.0` is used to signify that the token is very + // unlikely. + Logprob float64 `json:"logprob"` +} + +type ChatCompletionTokenLogprob struct { + // The token. + Token string `json:"token"` + // A list of integers representing the UTF-8 bytes representation of the token. + // Useful in instances where characters are represented by multiple tokens and + // their byte representations must be combined to generate the correct text + // representation. Can be `null` if there is no bytes representation for the token. + Bytes []int64 `json:"bytes,omitempty"` + // The log probability of this token, if it is within the top 20 most likely + // tokens. Otherwise, the value `-9999.0` is used to signify that the token is very + // unlikely. + Logprob float64 `json:"logprob"` + // List of the most likely tokens and their log probability, at this token + // position. In rare cases, there may be fewer than the number of requested + // `top_logprobs` returned. + TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"` +} + +// ChatCompletionChoicesLogprobs Log probability information for the choice. +type ChatCompletionChoicesLogprobs struct { + // A list of message content tokens with log probability information. + Content []ChatCompletionTokenLogprob `json:"content,omitempty"` + // A list of message refusal tokens with log probability information. + Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"` +} + // ChatCompletionResponseChoice is described in the OpenAI API documentation: // https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices type ChatCompletionResponseChoice struct { + // The reason the model stopped generating tokens. This will be `stop` if the model + // hit a natural stop point or a provided stop sequence, `length` if the maximum + // number of tokens specified in the request was reached, `content_filter` if + // content was omitted due to a flag from our content filters, `tool_calls` if the + // model called a tool, or `function_call` (deprecated) if the model called a + // function. + FinishReason ChatCompletionChoicesFinishReason `json:"finish_reason"` + // The index of the choice in the list of choices. + Index int64 `json:"index"` + // Log probability information for the choice. 
+ Logprobs ChatCompletionChoicesLogprobs `json:"logprobs,omitempty"` // Message is described in the OpenAI API documentation: // https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices Message ChatCompletionResponseChoiceMessage `json:"message,omitempty"` @@ -106,3 +671,19 @@ type ChatCompletionResponseChunkChoiceDelta struct { Content *string `json:"content,omitempty"` Role *string `json:"role,omitempty"` } + +// Error is described in the OpenAI API documentation +// https://platform.openai.com/docs/api-reference/realtime-server-events/error +type Error struct { + EventID *string `json:"event_id,omitempty"` + Type string `json:"type"` + Error ErrorType `json:"error"` +} + +type ErrorType struct { + Type string `json:"type"` + Code *string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + EventID *string `json:"event_id,omitempty"` +} diff --git a/internal/apischema/openai/openai_test.go b/internal/apischema/openai/openai_test.go new file mode 100644 index 00000000..803807ea --- /dev/null +++ b/internal/apischema/openai/openai_test.go @@ -0,0 +1,97 @@ +package openai + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/openai/openai-go" + "github.com/stretchr/testify/require" +) + +func TestOpenAIChatCompletionMessageUnmarshal(t *testing.T) { + for _, tc := range []struct { + name string + in []byte + out *ChatCompletionRequest + }{ + { + name: "basic test", + in: []byte(`{"model": "gpu-o4", + "messages": [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what do you see in this image"}]}`), + out: &ChatCompletionRequest{ + Model: "gpu-o4", + Messages: []ChatCompletionMessageParamUnion{ + { + Value: ChatCompletionSystemMessageParam{ + Role: ChatMessageRoleSystem, + Content: StringOrArray{ + Value: "you are a helpful assistant", + }, + }, + Type: ChatMessageRoleSystem, + }, + { + Value: ChatCompletionUserMessageParam{ + Role: ChatMessageRoleUser, + Content: StringOrUserRoleContentUnion{ + Value: "what do you see in this image", + }, + }, + Type: ChatMessageRoleUser, + }, + }, + }, + }, + { + name: "content with array", + in: []byte(`{"model": "gpu-o4", + "messages": [ + {"role": "system", "content": [{"text": "you are a helpful assistant", "type": "text"}]}, + {"role": "user", "content": [{"text": "what do you see in this image", "type": "text"}]}]}`), + out: &ChatCompletionRequest{ + Model: "gpu-o4", + Messages: []ChatCompletionMessageParamUnion{ + { + Value: ChatCompletionSystemMessageParam{ + Role: ChatMessageRoleSystem, + Content: StringOrArray{ + Value: []ChatCompletionContentPartTextParam{ + { + Text: "you are a helpful assistant", + Type: string(openai.ChatCompletionContentPartTextTypeText), + }, + }, + }, + }, + Type: ChatMessageRoleSystem, + }, + { + Value: ChatCompletionUserMessageParam{ + Role: ChatMessageRoleUser, + Content: StringOrUserRoleContentUnion{ + Value: []ChatCompletionContentPartUserUnionParam{ + { + TextContent: &ChatCompletionContentPartTextParam{Text: "what do you see in this image", Type: "text"}, + }, + }, + }, + }, + Type: ChatMessageRoleUser, + }, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var chatCompletion ChatCompletionRequest + err := json.Unmarshal(tc.in, &chatCompletion) + require.NoError(t, err) + if !cmp.Equal(&chatCompletion, tc.out) { + t.Errorf("UnmarshalOpenAIRequest(), diff(got, expected) = %s\n", cmp.Diff(&chatCompletion, tc.out)) + } + }) + } +} diff --git 
a/internal/extproc/processor.go b/internal/extproc/processor.go
index 28451597..d280b215 100644
--- a/internal/extproc/processor.go
+++ b/internal/extproc/processor.go
@@ -131,6 +131,11 @@ func (p *Processor) ProcessResponseHeaders(_ context.Context, headers *corev3.He
 	if enc := hs["content-encoding"]; enc != "" {
 		p.responseEncoding = enc
 	}
+	// The translator can be nil, as a response event may have been generated by a previous ext proc
+	// without this processor ever seeing the corresponding request event.
+	if p.translator == nil {
+		return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{}}, nil
+	}
 	headerMutation, err := p.translator.ResponseHeaders(hs)
 	if err != nil {
 		return nil, fmt.Errorf("failed to transform response: %w", err)
@@ -154,6 +159,11 @@ func (p *Processor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpB
 	default:
 		br = bytes.NewReader(body.Body)
 	}
+	// The translator can be nil, as a response event may have been generated by a previous ext proc
+	// without this processor ever seeing the corresponding request event.
+	if p.translator == nil {
+		return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseBody{}}, nil
+	}
 	headerMutation, bodyMutation, usedToken, err := p.translator.ResponseBody(br, body.EndOfStream)
 	if err != nil {
 		return nil, fmt.Errorf("failed to transform response: %w", err)
diff --git a/internal/extproc/translator/openai_awsbedrock.go b/internal/extproc/translator/openai_awsbedrock.go
index 0ddc5975..ce2cc70b 100644
--- a/internal/extproc/translator/openai_awsbedrock.go
+++ b/internal/extproc/translator/openai_awsbedrock.go
@@ -5,11 +5,14 @@ import (
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
+	"reflect"
+	"strings"
 
 	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
 	corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
 	extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
 	extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	"k8s.io/utils/ptr"
 
 	"github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock"
 	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
@@ -63,46 +66,36 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(body router.R
 		},
 	}
 
-	var awsReq awsbedrock.ConverseRequest
-	awsReq.Messages = make([]awsbedrock.Message, 0, len(openAIReq.Messages))
-	for _, msg := range openAIReq.Messages {
-		var role string
-		switch msg.Role {
-		case "user", "assistant":
-			role = msg.Role
-		case "system":
-			role = "assistant"
-		default:
-			return nil, nil, nil, fmt.Errorf("unexpected role: %s", msg.Role)
-		}
-
-		contents, ok := msg.Content.([]any)
-		if !ok {
-			return nil, nil, nil, fmt.Errorf("unexpected content: %[1]T:%[1]v", msg.Content)
-		}
-		for _, contentAny := range contents {
-			content, ok := contentAny.(map[string]any)
-			if !ok {
-				return nil, nil, nil, fmt.Errorf("unexpected content: %[1]T:%[1]v", contentAny)
-			}
-			textAny, ok := content["text"]
-			if !ok {
-				return nil, nil, nil, fmt.Errorf("missing text in content: %v", contents)
-			}
-
-			text, ok := textAny.(string)
-			if !ok {
-				return nil, nil, nil, fmt.Errorf("unexpected text: %[1]T:%[1]v", textAny)
-			}
-			awsReq.Messages = append(awsReq.Messages, awsbedrock.Message{
-				Role:    role,
-				Content: []awsbedrock.ContentBlock{{Text: text}},
-			})
+	var bedrockReq awsbedrock.ConverseInput
+	// Convert InferenceConfiguration.
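+	// These four parameters have the same semantics in both APIs, so the values are copied through as-is.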
+	bedrockReq.InferenceConfig = &awsbedrock.InferenceConfiguration{}
+	if openAIReq.MaxTokens != nil {
+		bedrockReq.InferenceConfig.MaxTokens = openAIReq.MaxTokens
+	}
+	if openAIReq.Stop != nil {
+		bedrockReq.InferenceConfig.StopSequences = openAIReq.Stop
+	}
+	if openAIReq.Temperature != nil {
+		bedrockReq.InferenceConfig.Temperature = openAIReq.Temperature
+	}
+	if openAIReq.TopP != nil {
+		bedrockReq.InferenceConfig.TopP = openAIReq.TopP
+	}
+	// Convert Chat Completion messages.
+	err = o.OpenAIMessageToBedrockMessage(openAIReq, &bedrockReq)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	// Convert ToolConfiguration.
+	if len(openAIReq.Tools) > 0 {
+		err = o.openAIToolsToBedrockToolConfiguration(openAIReq, &bedrockReq)
+		if err != nil {
+			return nil, nil, nil, err
+		}
 	}
 
 	mut := &extprocv3.BodyMutation_Body{}
-	if body, err := json.Marshal(awsReq); err != nil {
+	if body, err := json.Marshal(bedrockReq); err != nil {
 		return nil, nil, nil, fmt.Errorf("failed to marshal body: %w", err)
 	} else {
 		mut.Body = body
@@ -111,6 +104,168 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(body router.R
 	return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, override, nil
 }
 
+// openAIToolsToBedrockToolConfiguration converts OpenAI ChatCompletion tools into an AWS Bedrock tool configuration.
+func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIToolsToBedrockToolConfiguration(openAIReq *openai.ChatCompletionRequest,
+	bedrockReq *awsbedrock.ConverseInput,
+) error {
+	bedrockReq.ToolConfig = &awsbedrock.ToolConfiguration{}
+	tools := make([]*awsbedrock.Tool, 0, len(openAIReq.Tools))
+	for _, toolDefinition := range openAIReq.Tools {
+		tool := &awsbedrock.Tool{
+			ToolSpec: &awsbedrock.ToolSpecification{
+				// The Bedrock tool name is the function name, not the OpenAI tool type ("function").
+				Name:        &toolDefinition.Function.Name,
+				Description: &toolDefinition.Function.Description,
+				InputSchema: &awsbedrock.ToolInputSchema{
+					JSON: toolDefinition.Function.Parameters,
+				},
+			},
+		}
+		tools = append(tools, tool)
+	}
+	bedrockReq.ToolConfig.Tools = tools
+
+	if openAIReq.ToolChoice != nil {
+		switch reflect.TypeOf(openAIReq.ToolChoice).Kind() {
+		case reflect.String:
+			if openAIReq.ToolChoice.(string) == "auto" {
+				bedrockReq.ToolConfig.ToolChoice = &awsbedrock.ToolChoice{
+					Auto: &awsbedrock.AutoToolChoice{},
+				}
+			} else {
+				bedrockReq.ToolConfig.ToolChoice = &awsbedrock.ToolChoice{
+					Any: &awsbedrock.AnyToolChoice{},
+				}
+			}
+		case reflect.Struct:
+			toolChoice := openAIReq.ToolChoice.(openai.ToolChoice)
+			toolName := toolChoice.Function.Name
+			bedrockReq.ToolConfig.ToolChoice = &awsbedrock.ToolChoice{
+				Tool: &awsbedrock.SpecificToolChoice{
+					Name: &toolName,
+				},
+			}
+		default:
+			return fmt.Errorf("unexpected type: %s", reflect.TypeOf(openAIReq.ToolChoice).Kind())
+		}
+	}
+	return nil
+}
+
+// OpenAIMessageToBedrockMessage converts OpenAI ChatCompletion messages to AWS Bedrock messages.
+func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) OpenAIMessageToBedrockMessage(openAIReq *openai.ChatCompletionRequest,
+	bedrockReq *awsbedrock.ConverseInput,
+) error {
+	// Convert Messages.
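+	// Bedrock has no system or tool role: system messages map to the top-level system field,
+	// and tool results are folded into user messages below.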
+ bedrockReq.Messages = make([]*awsbedrock.Message, 0, len(openAIReq.Messages)) + for _, msg := range openAIReq.Messages { + switch msg.Type { + case openai.ChatMessageRoleUser: + message := msg.Value.(openai.ChatCompletionUserMessageParam) + if text, ok := message.Content.Value.(string); ok { + bedrockReq.Messages = append(bedrockReq.Messages, &awsbedrock.Message{ + Role: msg.Type, + Content: []*awsbedrock.ContentBlock{ + {Text: ptr.To(text)}, + }, + }) + } else if contents, ok := message.Content.Value.([]openai.ChatCompletionContentPartUserUnionParam); ok { + chatMessage := &awsbedrock.Message{Role: msg.Type} + chatMessage.Content = make([]*awsbedrock.ContentBlock, 0, len(contents)) + for _, contentPart := range contents { + if contentPart.TextContent != nil { + textContentPart := contentPart.TextContent + chatMessage.Content = append(chatMessage.Content, &awsbedrock.ContentBlock{ + Text: &textContentPart.Text, + }) + } else if contentPart.ImageContent != nil { + imageContentPart := contentPart.ImageContent + // Expect a data URL of the form "data:image/<format>;base64,<payload>". + parts := strings.Split(imageContentPart.ImageURL.URL, ",") + if len(parts) == 2 { + formatPart := strings.Split(parts[0], ";")[0] + format := strings.TrimPrefix(formatPart, "data:image/") + chatMessage.Content = append(chatMessage.Content, &awsbedrock.ContentBlock{ + Image: &awsbedrock.ImageBlock{ + Format: format, + Source: awsbedrock.ImageSource{ + Bytes: []byte(parts[1]), + }, + }, + }) + } else { + return fmt.Errorf("unexpected image data url") + } + } + } + bedrockReq.Messages = append(bedrockReq.Messages, chatMessage) + } else { + return fmt.Errorf("unexpected content type for user message") + } + case openai.ChatMessageRoleAssistant: + message := msg.Value.(openai.ChatCompletionAssistantMessageParam) + if message.Content.Type == openai.ChatCompletionAssistantMessageParamContentTypeRefusal { + bedrockReq.Messages = append(bedrockReq.Messages, &awsbedrock.Message{ + Role: msg.Type, + Content: []*awsbedrock.ContentBlock{ + {Text: message.Content.Refusal}, + }, + }) + } else { + bedrockReq.Messages = append(bedrockReq.Messages, &awsbedrock.Message{ + Role: msg.Type, + Content: []*awsbedrock.ContentBlock{ + {Text: message.Content.Text}, + }, + }) + } + case openai.ChatMessageRoleSystem: + message := msg.Value.(openai.ChatCompletionSystemMessageParam) + if bedrockReq.System == nil { + bedrockReq.System = []*awsbedrock.SystemContentBlock{} + } + if text, ok := message.Content.Value.(string); ok { + bedrockReq.System = append(bedrockReq.System, &awsbedrock.SystemContentBlock{ + Text: text, + }) + } else if contents, ok := message.Content.Value.([]openai.ChatCompletionContentPartTextParam); ok { + for _, contentPart := range contents { + bedrockReq.System = append(bedrockReq.System, &awsbedrock.SystemContentBlock{ + Text: contentPart.Text, + }) + } + } else { + return fmt.Errorf("unexpected content type for system message") + } + case openai.ChatMessageRoleTool: + message := msg.Value.(openai.ChatCompletionToolMessageParam) + bedrockReq.Messages = append(bedrockReq.Messages, &awsbedrock.Message{ + // Bedrock has no tool role; tool results are merged into a user-role message. + Role: awsbedrock.ConversationRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + ToolResult: &awsbedrock.ToolResultBlock{ + Content: []*awsbedrock.ToolResultContentBlock{ + { + Text: ptr.To(message.Content.Value.(string)), + }, + }, + }, + }, + }, + }) + default: + return fmt.Errorf("unexpected role: %s", msg.Type) + } + } + return
nil +} + // ResponseHeaders implements [Translator.ResponseHeaders]. func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseHeaders(headers map[string]string) ( headerMutation *extprocv3.HeaderMutation, err error, @@ -169,29 +324,46 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(body io.Read return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, usedToken, nil } - var awsResp awsbedrock.ConverseResponse - if err := json.NewDecoder(body).Decode(&awsResp); err != nil { + var bedrockResp awsbedrock.ConverseOutput + if err := json.NewDecoder(body).Decode(&bedrockResp); err != nil { return nil, nil, 0, fmt.Errorf("failed to unmarshal body: %w", err) } - usedToken = uint32(awsResp.Usage.TotalTokens) openAIResp := openai.ChatCompletionResponse{ - Usage: openai.ChatCompletionResponseUsage{ - TotalTokens: awsResp.Usage.TotalTokens, - PromptTokens: awsResp.Usage.InputTokens, - CompletionTokens: awsResp.Usage.OutputTokens, - }, Object: "chat.completion", - Choices: make([]openai.ChatCompletionResponseChoice, 0, len(awsResp.Output.Message.Content)), + Choices: make([]openai.ChatCompletionResponseChoice, 0, len(bedrockResp.Output.Message.Content)), } - - for _, output := range awsResp.Output.Message.Content { - t := output.Text - openAIResp.Choices = append(openAIResp.Choices, openai.ChatCompletionResponseChoice{Message: openai.ChatCompletionResponseChoiceMessage{ - Content: &t, - Role: awsResp.Output.Message.Role, - }}) + if bedrockResp.Usage != nil { + openAIResp.Usage = openai.ChatCompletionResponseUsage{ + TotalTokens: bedrockResp.Usage.TotalTokens, + PromptTokens: bedrockResp.Usage.InputTokens, + CompletionTokens: bedrockResp.Usage.OutputTokens, + } + usedToken = uint32(bedrockResp.Usage.TotalTokens) + } + for i, output := range bedrockResp.Output.Message.Content { + choice := openai.ChatCompletionResponseChoice{ + Index: int64(i), + Message: openai.ChatCompletionResponseChoiceMessage{ + Content: output.Text, + Role: bedrockResp.Output.Message.Role, + }, + } + if bedrockResp.StopReason != nil { + switch *bedrockResp.StopReason { + case awsbedrock.StopReasonStopSequence, awsbedrock.StopReasonEndTurn: + choice.FinishReason = openai.ChatCompletionChoicesFinishReasonStop + case awsbedrock.StopReasonMaxTokens: + choice.FinishReason = openai.ChatCompletionChoicesFinishReasonLength + case awsbedrock.StopReasonContentFiltered: + choice.FinishReason = openai.ChatCompletionChoicesFinishReasonContentFilter + case awsbedrock.StopReasonToolUse: + choice.FinishReason = openai.ChatCompletionChoicesFinishReasonToolCalls + } + } + openAIResp.Choices = append(openAIResp.Choices, choice) } if body, err := json.Marshal(openAIResp); err != nil { diff --git a/internal/extproc/translator/openai_awsbedrock_test.go b/internal/extproc/translator/openai_awsbedrock_test.go index 4f56feba..f4bcc25f 100644 --- a/internal/extproc/translator/openai_awsbedrock_test.go +++ b/internal/extproc/translator/openai_awsbedrock_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/base64" "encoding/json" - "fmt" "strconv" "strings" "testing" @@ -12,7 +11,9 @@ import ( "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" + "k8s.io/utils/ptr" "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock"
"github.com/envoyproxy/ai-gateway/internal/apischema/openai" @@ -37,66 +38,376 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T) _, _, _, err := o.RequestBody(&extprocv3.HttpBody{Body: []byte("invalid")}) require.Error(t, err) }) - t.Run("valid body", func(t *testing.T) { - contentify := func(msg string) any { - return []any{map[string]any{"text": msg}} - } - for _, stream := range []bool{true, false} { - t.Run(fmt.Sprintf("stream=%t", stream), func(t *testing.T) { - o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} - originalReq := openai.ChatCompletionRequest{ - Stream: stream, - Model: "gpt-4o", - Messages: []openai.ChatCompletionRequestMessage{ - {Content: contentify("from-system"), Role: "system"}, - {Content: contentify("from-user"), Role: "user"}, - {Content: contentify("part1"), Role: "user"}, - {Content: contentify("part2"), Role: "user"}, - }, - } - - hm, bm, mode, err := o.RequestBody(router.RequestBody(&originalReq)) - var expPath string - if stream { - expPath = "/model/gpt-4o/converse-stream" - require.True(t, o.stream) - require.NotNil(t, mode) - require.Equal(t, extprocv3http.ProcessingMode_STREAMED, mode.ResponseBodyMode) - require.Equal(t, extprocv3http.ProcessingMode_SEND, mode.ResponseHeaderMode) - } else { - expPath = "/model/gpt-4o/converse" - require.False(t, o.stream) - require.Nil(t, mode) - } - require.NoError(t, err) - require.NotNil(t, hm) - require.NotNil(t, hm.SetHeaders) - require.Len(t, hm.SetHeaders, 2) - require.Equal(t, ":path", hm.SetHeaders[0].Header.Key) - require.Equal(t, expPath, string(hm.SetHeaders[0].Header.RawValue)) - require.Equal(t, "content-length", hm.SetHeaders[1].Header.Key) - newBody := bm.Mutation.(*extprocv3.BodyMutation_Body).Body - require.Equal(t, strconv.Itoa(len(newBody)), string(hm.SetHeaders[1].Header.RawValue)) + tests := []struct { + name string + output awsbedrock.ConverseInput + input openai.ChatCompletionRequest + }{ + { + name: "basic test", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4o", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + Value: openai.ChatCompletionSystemMessageParam{ + Content: openai.StringOrArray{ + Value: "from-system", + }, + }, Type: openai.ChatMessageRoleSystem, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "from-user", + }, + }, Type: openai.ChatMessageRoleUser, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "part1", + }, + }, Type: openai.ChatMessageRoleUser, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "part2", + }, + }, Type: openai.ChatMessageRoleUser, + }, + }, + }, + output: awsbedrock.ConverseInput{ + InferenceConfig: &awsbedrock.InferenceConfiguration{}, + System: []*awsbedrock.SystemContentBlock{ + { + Text: "from-system", + }, + }, + Messages: []*awsbedrock.Message{ + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("from-user"), + }, + }, + }, + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("part1"), + }, + }, + }, + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("part2"), + }, + }, + }, + }, + }, + }, + { + name: "test content array", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4o", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + 
Value: openai.ChatCompletionSystemMessageParam{ + Content: openai.StringOrArray{ + Value: []openai.ChatCompletionContentPartTextParam{ + {Text: "from-system"}, + }, + }, + }, Type: openai.ChatMessageRoleSystem, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + {TextContent: &openai.ChatCompletionContentPartTextParam{Text: "from-user"}}, + }, + }, + }, Type: openai.ChatMessageRoleUser, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + {TextContent: &openai.ChatCompletionContentPartTextParam{Text: "user1"}}, + }, + }, + }, Type: openai.ChatMessageRoleUser, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + {TextContent: &openai.ChatCompletionContentPartTextParam{Text: "user2"}}, + }, + }, + }, Type: openai.ChatMessageRoleUser, + }, + }, + }, + output: awsbedrock.ConverseInput{ + InferenceConfig: &awsbedrock.InferenceConfiguration{}, + System: []*awsbedrock.SystemContentBlock{ + { + Text: "from-system", + }, + }, + Messages: []*awsbedrock.Message{ + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("from-user"), + }, + }, + }, + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("user1"), + }, + }, + }, + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("user2"), + }, + }, + }, + }, + }, + }, + { + name: "test image", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4o", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + Value: openai.ChatCompletionSystemMessageParam{ + Content: openai.StringOrArray{ + Value: []openai.ChatCompletionContentPartTextParam{ + {Text: "from-system"}, + }, + }, + }, Type: openai.ChatMessageRoleSystem, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + {ImageContent: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "data:image/jpeg;base64,dGVzdAo=", + }, + }}, + }, + }, + }, Type: openai.ChatMessageRoleUser, + }, + }, + }, + output: awsbedrock.ConverseInput{ + InferenceConfig: &awsbedrock.InferenceConfiguration{}, + System: []*awsbedrock.SystemContentBlock{ + { + Text: "from-system", + }, + }, + Messages: []*awsbedrock.Message{ + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Image: &awsbedrock.ImageBlock{ + Source: awsbedrock.ImageSource{ + Bytes: []byte("dGVzdAo="), + }, + Format: "jpeg", + }, + }, + }, + }, + }, + }, + }, + { + name: "test parameters", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4o", + MaxTokens: ptr.To(int64(10)), + TopP: ptr.To(float64(1)), + Temperature: ptr.To(0.7), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "from-user", + }, + }, Type: openai.ChatMessageRoleUser, + }, + }, + }, + output: awsbedrock.ConverseInput{ + InferenceConfig: &awsbedrock.InferenceConfiguration{ + MaxTokens: ptr.To(int64(10)), + TopP: ptr.To(float64(1)), + Temperature: ptr.To(0.7), + }, + Messages: 
[]*awsbedrock.Message{ + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("from-user"), + }, + }, + }, + }, + }, + }, + { + name: "test function calling", + input: openai.ChatCompletionRequest{ + Stream: false, + Model: "gpt-4o", + MaxTokens: ptr.To(int64(10)), + TopP: ptr.To(float64(1)), + Temperature: ptr.To(0.7), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "from-user", + }, + }, Type: openai.ChatMessageRoleUser, + }, + }, + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_current_weather", + Description: "Get the current weather in a given location", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + "required": []string{"location"}, + }, + }, + }, + }, + ToolChoice: "auto", + }, + output: awsbedrock.ConverseInput{ + InferenceConfig: &awsbedrock.InferenceConfiguration{ + MaxTokens: ptr.To(int64(10)), + TopP: ptr.To(float64(1)), + Temperature: ptr.To(0.7), + }, + Messages: []*awsbedrock.Message{ + { + Role: openai.ChatMessageRoleUser, + Content: []*awsbedrock.ContentBlock{ + { + Text: ptr.To("from-user"), + }, + }, + }, + }, + ToolConfig: &awsbedrock.ToolConfiguration{ + ToolChoice: &awsbedrock.ToolChoice{ + Auto: &awsbedrock.AutoToolChoice{}, + }, + Tools: []*awsbedrock.Tool{ + { + ToolSpec: &awsbedrock.ToolSpecification{ + Name: ptr.To("get_current_weather"), + Description: ptr.To("Get the current weather in a given location"), + InputSchema: &awsbedrock.ToolInputSchema{ + JSON: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g.
San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []any{"celsius", "fahrenheit"}, + }, + }, + "required": []any{"location"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + originalReq := tt.input + hm, bm, mode, err := o.RequestBody(router.RequestBody(&originalReq)) + var expPath string + if tt.input.Stream { + expPath = "/model/gpt-4o/converse-stream" + require.True(t, o.stream) + require.NotNil(t, mode) + require.Equal(t, extprocv3http.ProcessingMode_STREAMED, mode.ResponseBodyMode) + require.Equal(t, extprocv3http.ProcessingMode_SEND, mode.ResponseHeaderMode) + } else { + expPath = "/model/gpt-4o/converse" + require.False(t, o.stream) + require.Nil(t, mode) + } + require.NoError(t, err) + require.NotNil(t, hm) + require.NotNil(t, hm.SetHeaders) + require.Len(t, hm.SetHeaders, 2) + require.Equal(t, ":path", hm.SetHeaders[0].Header.Key) + require.Equal(t, expPath, string(hm.SetHeaders[0].Header.RawValue)) + require.Equal(t, "content-length", hm.SetHeaders[1].Header.Key) + newBody := bm.Mutation.(*extprocv3.BodyMutation_Body).Body + require.Equal(t, strconv.Itoa(len(newBody)), string(hm.SetHeaders[1].Header.RawValue)) - var awsReq awsbedrock.ConverseRequest - err = json.Unmarshal(newBody, &awsReq) - require.NoError(t, err) - require.NotNil(t, awsReq.Messages) - require.Len(t, awsReq.Messages, 4) - for _, msg := range awsReq.Messages { - t.Log(msg) - } - require.Equal(t, "assistant", awsReq.Messages[0].Role) - require.Equal(t, "from-system", awsReq.Messages[0].Content[0].Text) - require.Equal(t, "user", awsReq.Messages[1].Role) - require.Equal(t, "from-user", awsReq.Messages[1].Content[0].Text) - require.Equal(t, "user", awsReq.Messages[2].Role) - require.Equal(t, "part1", awsReq.Messages[2].Content[0].Text) - require.Equal(t, "user", awsReq.Messages[3].Role) - require.Equal(t, "part2", awsReq.Messages[3].Content[0].Text) - }) - } - }) + var awsReq awsbedrock.ConverseInput + err = json.Unmarshal(newBody, &awsReq) + require.NoError(t, err) + if !cmp.Equal(awsReq, tt.output) { + t.Errorf("ConvertOpenAIToBedrock(), diff(got, expected) = %s\n", cmp.Diff(awsReq, tt.output)) + } + }) + } } func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseHeaders(t *testing.T) { @@ -120,7 +431,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseHeaders(t *testing }) } -func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) { +func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_Streaming_ResponseBody(t *testing.T) { t.Run("streaming", func(t *testing.T) { o := &openAIToAWSBedrockTranslatorV1ChatCompletion{stream: true} buf, err := base64.StdEncoding.DecodeString(base64RealStreamingEvents) @@ -161,33 +472,113 @@ data: {"object":"chat.completion.chunk","usage":{"completion_tokens":36,"prompt_ data: [DONE] `, result) }) - t.Run("non-streaming", func(t *testing.T) { - t.Run("invalid body", func(t *testing.T) { - o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} - _, _, _, err := o.ResponseBody(bytes.NewBuffer([]byte("invalid")), false) - require.Error(t, err) - }) - t.Run("valid body", func(t *testing.T) { - originalAWSResp := awsbedrock.ConverseResponse{ - Usage: awsbedrock.TokenUsage{ +} + +func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) { + t.Run("invalid body", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + _, _, _, err := 
o.ResponseBody(bytes.NewBuffer([]byte("invalid")), false) + require.Error(t, err) + }) + tests := []struct { + name string + input awsbedrock.ConverseOutput + output openai.ChatCompletionResponse + }{ + { + name: "basic testing", + input: awsbedrock.ConverseOutput{ + Usage: &awsbedrock.TokenUsage{ InputTokens: 10, OutputTokens: 20, TotalTokens: 30, }, - Output: awsbedrock.ConverseResponseOutput{ + Output: &awsbedrock.ConverseOutput_{ Message: awsbedrock.Message{ Role: "assistant", - Content: []awsbedrock.ContentBlock{ - {Text: "response"}, - {Text: "from"}, - {Text: "assistant"}, + Content: []*awsbedrock.ContentBlock{ + {Text: ptr.To("response")}, + {Text: ptr.To("from")}, + {Text: ptr.To("assistant")}, }, }, }, - } - body, err := json.Marshal(originalAWSResp) - require.NoError(t, err) + }, + output: openai.ChatCompletionResponse{ + Object: "chat.completion", + Usage: openai.ChatCompletionResponseUsage{ + TotalTokens: 30, + PromptTokens: 10, + CompletionTokens: 20, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Content: ptr.To("response"), + Role: "assistant", + }, + }, + { + Index: 1, + Message: openai.ChatCompletionResponseChoiceMessage{ + Content: ptr.To("from"), + Role: "assistant", + }, + }, + { + Index: 2, + Message: openai.ChatCompletionResponseChoiceMessage{ + Content: ptr.To("assistant"), + Role: "assistant", + }, + }, + }, + }, + }, + { + name: "test stop reason", + input: awsbedrock.ConverseOutput{ + Usage: &awsbedrock.TokenUsage{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + StopReason: ptr.To(awsbedrock.StopReasonStopSequence), + Output: &awsbedrock.ConverseOutput_{ + Message: awsbedrock.Message{ + Role: awsbedrock.ConversationRoleAssistant, + Content: []*awsbedrock.ContentBlock{ + {Text: ptr.To("response")}, + }, + }, + }, + }, + output: openai.ChatCompletionResponse{ + Object: "chat.completion", + Usage: openai.ChatCompletionResponseUsage{ + TotalTokens: 30, + PromptTokens: 10, + CompletionTokens: 20, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + FinishReason: openai.ChatCompletionChoicesFinishReasonStop, + Message: openai.ChatCompletionResponseChoiceMessage{ + Content: ptr.To("response"), + Role: awsbedrock.ConversationRoleAssistant, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, err := json.Marshal(tt.input) + require.NoError(t, err) o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} hm, bm, usedToken, err := o.ResponseBody(bytes.NewBuffer(body), false) require.NoError(t, err) @@ -205,20 +596,12 @@ data: [DONE] var openAIResp openai.ChatCompletionResponse err = json.Unmarshal(newBody, &openAIResp) require.NoError(t, err) - require.NotNil(t, openAIResp.Usage) require.Equal(t, uint32(30), usedToken) - require.Equal(t, 30, openAIResp.Usage.TotalTokens) - require.Equal(t, 10, openAIResp.Usage.PromptTokens) - require.Equal(t, 20, openAIResp.Usage.CompletionTokens) - - require.NotNil(t, openAIResp.Choices) - require.Len(t, openAIResp.Choices, 3) - - require.Equal(t, "response", *openAIResp.Choices[0].Message.Content) - require.Equal(t, "from", *openAIResp.Choices[1].Message.Content) - require.Equal(t, "assistant", *openAIResp.Choices[2].Message.Content) + if !cmp.Equal(openAIResp, tt.output) { + t.Errorf("ResponseBody(), diff(got, expected) = %s\n", cmp.Diff(openAIResp, tt.output)) + } }) - }) + } } const base64RealStreamingEvents =
"AAAAnwAAAFKzEV9wCzpldmVudC10eXBlBwAMbWVzc2FnZVN0YXJ0DTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsicCI6ImFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6QUJDREVGR0giLCJyb2xlIjoiYXNzaXN0YW50In0i9wVBAAAAxQAAAFex2HyVCzpldmVudC10eXBlBwARY29udGVudEJsb2NrRGVsdGENOmNvbnRlbnQtdHlwZQcAEGFwcGxpY2F0aW9uL2pzb24NOm1lc3NhZ2UtdHlwZQcABWV2ZW50eyJjb250ZW50QmxvY2tJbmRleCI6MCwiZGVsdGEiOnsidGV4dCI6IkRvbiJ9LCJwIjoiYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNERUZHSElKS0xNTk8ifb/whawAAADAAAAAV3k48+ULOmV2ZW50LXR5cGUHABFjb250ZW50QmxvY2tEZWx0YQ06Y29udGVudC10eXBlBwAQYXBwbGljYXRpb24vanNvbg06bWVzc2FnZS10eXBlBwAFZXZlbnR7ImNvbnRlbnRCbG9ja0luZGV4IjowLCJkZWx0YSI6eyJ0ZXh0IjoiJ3Qgd29ycnksIEknbSBoZXJlIHRvIGhlbHAuIEl0In0sInAiOiJhYmNkZWZnaGkifenahv0AAADgAAAAV7j53OELOmV2ZW50LXR5cGUHABFjb250ZW50QmxvY2tEZWx0YQ06Y29udGVudC10eXBlBwAQYXBwbGljYXRpb24vanNvbg06bWVzc2FnZS10eXBlBwAFZXZlbnR7ImNvbnRlbnRCbG9ja0luZGV4IjowLCJkZWx0YSI6eyJ0ZXh0IjoiIHNlZW1zIGxpa2UgeW91J3JlIHRlc3RpbmcgbXkgYWJpbGl0eSB0byByZXNwb25kIGFwcHJvcHJpYXRlbHkifSwicCI6ImFiY2RlZmdoaSJ9dNZCqAAAAM8AAABX+2hkNAs6ZXZlbnQtdHlwZQcAEWNvbnRlbnRCbG9ja0RlbHRhDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsiY29udGVudEJsb2NrSW5kZXgiOjAsImRlbHRhIjp7InRleHQiOiIuIElmIHlvdSdkIGxpa2UgdG8gY29udGludWUgdGhlIHRlc3QsIn0sInAiOiJhYmNkZWZnaGlqa2xtbm9wcSJ9xQJqAgAAALUAAABXSAqcWgs6ZXZlbnQtdHlwZQcAEWNvbnRlbnRCbG9ja0RlbHRhDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsiY29udGVudEJsb2NrSW5kZXgiOjAsImRlbHRhIjp7InRleHQiOiIgSSdtIHJlYWR5LiJ9LCJwIjoiYWJjZGVmZ2hpamtsbW5vcHEifTOb7esAAAC5AAAAVvr9Qc0LOmV2ZW50LXR5cGUHABBjb250ZW50QmxvY2tTdG9wDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsiY29udGVudEJsb2NrSW5kZXgiOjAsInAiOiJhYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ekFCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaMCJ9iABE1AAAAI0AAABRMDjKKAs6ZXZlbnQtdHlwZQcAC21lc3NhZ2VTdG9wDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsicCI6ImFiY2RlZmdoaWprbCIsInN0b3BSZWFzb24iOiJlbmRfdHVybiJ9LttU3QAAAPoAAABO9sL7Ags6ZXZlbnQtdHlwZQcACG1ldGFkYXRhDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsibWV0cmljcyI6eyJsYXRlbmN5TXMiOjQ1Mn0sInAiOiJhYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ekFCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaMDEyMzQ1IiwidXNhZ2UiOnsiaW5wdXRUb2tlbnMiOjQxLCJvdXRwdXRUb2tlbnMiOjM2LCJ0b3RhbFRva2VucyI6Nzd9fX96gYI="