Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add developer message for o1 models #4923

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FunctionExecutionResult,
FunctionExecutionResultMessage,
SystemMessage,
DeveloperMessage,
UserMessage,
)
from autogen_core.tools import FunctionTool, Tool
Expand Down Expand Up @@ -238,15 +239,22 @@ def __init__(
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
developer_message: (
str | None
) = None,
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
):
super().__init__(name=name, description=description)
self._model_client = model_client
if system_message is None:
if system_message is None or developer_message is not None:
Copy link
Member

@afourney afourney Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than checking which system message type we've received, we should be checking which model family we were given. We can then load defaults accordingly, and we can turn the system message into a developer message as needed (again, I might recommend a name like sys_or_dev_message -- but that would be a breaking change I guess)

self._system_messages = []
else:
self._system_messages = [SystemMessage(content=system_message)]
if developer_message is None:
self._developer_messages = []
else:
self._developer_messages = [DeveloperMessage(content=developer_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.model_info["function_calling"] is False:
Expand Down Expand Up @@ -326,7 +334,10 @@ async def on_messages_stream(
inner_messages: List[AgentEvent | ChatMessage] = []

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
if self._developer_messages:
llm_messages = self._developer_messages + await self._model_context.get_messages()
else:
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
Expand Down Expand Up @@ -379,7 +390,10 @@ async def on_messages_stream(

if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
if self._developer_messages:
bassmang marked this conversation as resolved.
Show resolved Hide resolved
llm_messages = self._developer_messages + await self._model_context.get_messages()
else:
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the model context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
LLMMessage,
RequestUsage,
SystemMessage,
DeveloperMessage,
TopLogprob,
UserMessage,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
class SystemMessage(BaseModel):
    """A system-role instruction message (mapped to the OpenAI ``system`` role)."""

    # The instruction text sent to the model.
    content: str
    # Literal discriminator so the LLMMessage union can be (de)serialized by tag.
    type: Literal["SystemMessage"] = "SystemMessage"


class DeveloperMessage(BaseModel):
    """A developer-role instruction message (mapped to the OpenAI ``developer`` role,
    used by o1-family models in place of a system message)."""

    # The instruction text sent to the model.
    content: str
    # Literal discriminator so the LLMMessage union can be (de)serialized by tag.
    type: Literal["DeveloperMessage"] = "DeveloperMessage"

class UserMessage(BaseModel):
content: Union[str, List[Union[str, Image]]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
ModelInfo,
RequestUsage,
SystemMessage,
DeveloperMessage,
TopLogprob,
UserMessage,
)
Expand All @@ -58,6 +59,7 @@
ChatCompletionMessageToolCallParam,
ChatCompletionRole,
ChatCompletionSystemMessageParam,
ChatCompletionDeveloperMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
Expand Down Expand Up @@ -170,6 +172,13 @@ def system_message_to_oai(message: SystemMessage) -> ChatCompletionSystemMessage
content=message.content,
role="system",
)


def developer_message_to_oai(message: DeveloperMessage) -> ChatCompletionDeveloperMessageParam:
    """Convert an autogen ``DeveloperMessage`` into the OpenAI chat-completions
    ``developer``-role message param."""
    oai_message: ChatCompletionDeveloperMessageParam = ChatCompletionDeveloperMessageParam(
        role="developer",
        content=message.content,
    )
    return oai_message


def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
Expand Down Expand Up @@ -212,6 +221,8 @@ def assistant_message_to_oai(
def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
if isinstance(message, SystemMessage):
return [system_message_to_oai(message)]
elif isinstance(message, DeveloperMessage):
return [developer_message_to_oai(message)]
elif isinstance(message, UserMessage):
return [user_message_to_oai(message)]
elif isinstance(message, AssistantMessage):
Expand Down
Loading