From 21c7295ef91d272008ec9049c855c6e9ce22a57a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 10 Jan 2025 11:02:22 -0800 Subject: [PATCH] Add back all_rulesets in favor of rulesets property cleverness --- .../response/prompt_response_rag_module.py | 4 +- griptape/mixins/rule_mixin.py | 8 +-- griptape/tasks/base_image_generation_task.py | 2 +- griptape/tasks/extraction_task.py | 4 +- griptape/tasks/prompt_task.py | 10 ++-- griptape/tasks/text_summary_task.py | 2 +- griptape/tasks/tool_task.py | 2 +- griptape/tools/prompt_summary/tool.py | 2 +- griptape/tools/query/tool.py | 2 +- tests/unit/mixins/test_rule_mixin.py | 51 +++++++++---------- tests/unit/structures/test_agent.py | 24 ++++----- tests/unit/structures/test_pipeline.py | 31 +++++------ tests/unit/structures/test_structure.py | 2 + tests/unit/structures/test_workflow.py | 28 +++++----- tests/unit/tasks/test_base_text_input_task.py | 8 +-- tests/unit/tasks/test_prompt_task.py | 18 +++---- tests/unit/tasks/test_tool_task.py | 1 + tests/unit/tasks/test_toolkit_task.py | 1 + tests/utils/structure_tester.py | 2 +- 19 files changed, 102 insertions(+), 100 deletions(-) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 2e4f39947..29e160a24 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -56,8 +56,8 @@ def run(self, context: RagContext) -> BaseArtifact: def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str: params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} - if len(self.rulesets) > 0: - params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets) + if len(self.all_rulesets) > 0: + params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets) if self.metadata is not None: params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata) diff --git a/griptape/mixins/rule_mixin.py b/griptape/mixins/rule_mixin.py index cae731865..6d573c61a 100644 --- a/griptape/mixins/rule_mixin.py +++ b/griptape/mixins/rule_mixin.py @@ -12,14 +12,14 @@ class RuleMixin(SerializableMixin): DEFAULT_RULESET_NAME = "Default Ruleset" - _rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets", metadata={"serializable": True}) - rules: list[BaseRule] = field(factory=list, kw_only=True) + rulesets: list[Ruleset] = field(factory=list, kw_only=True, metadata={"serializable": True}) + rules: list[BaseRule] = field(factory=list, kw_only=True, metadata={"serializable": True}) _default_ruleset_name: str = field(default=Factory(lambda: RuleMixin.DEFAULT_RULESET_NAME), kw_only=True) _default_ruleset_id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @property - def rulesets(self) -> list[Ruleset]: - rulesets = self._rulesets.copy() + def all_rulesets(self) -> list[Ruleset]: + rulesets = self.rulesets.copy() if self.rules: rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules)) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 8705d489b..31ac1d660 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -59,7 +59,7 @@ def _read_from_file(self, path: str) -> ImageArtifact: return ImageLoader().load(Path(path)) def _get_prompts(self, prompt: str) -> list[str]: - return [prompt, *[rule.value for ruleset in self.rulesets for rule in ruleset.rules]] + return [prompt, *[rule.value for ruleset in self.all_rulesets for rule in ruleset.rules]] def _get_negative_prompts(self) -> list[str]: return [rule.value for ruleset in self.negative_rulesets for rule in ruleset.rules] diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index 35cc83003..5d5a3b3f2 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -18,4 +18,6 @@ class ExtractionTask(BaseTextInputTask): args: dict = field(kw_only=True, factory=dict) def try_run(self) -> ListArtifact | ErrorArtifact: - return self.extraction_engine.extract_artifacts(ListArtifact([self.input]), rulesets=self.rulesets, **self.args) + return self.extraction_engine.extract_artifacts( + ListArtifact([self.input]), rulesets=self.all_rulesets, **self.args + ) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index d2de27ac6..224056742 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -65,13 +65,13 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin): response_stop_sequence: str = field(default=RESPONSE_STOP_SEQUENCE, kw_only=True) @property - def rulesets(self) -> list: + def all_rulesets(self) -> list: default_rules = self.rules - rulesets = self._rulesets.copy() + rulesets = self.rulesets.copy() if self.structure is not None: - if self.structure._rulesets: - rulesets = self.structure._rulesets + self._rulesets + if self.structure.rulesets: + rulesets = self.structure.rulesets + self.rulesets if self.structure.rules: default_rules = self.structure.rules + self.rules @@ -213,7 +213,7 @@ def default_generate_system_template(self, _: PromptTask) -> str: schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. return J2("tasks/prompt_task/system.j2").render( - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index 5cb7510ac..68e9f2d87 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -17,4 +17,4 @@ class TextSummaryTask(BaseTextInputTask): summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True) def try_run(self) -> TextArtifact: - return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.rulesets)) + return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.all_rulesets)) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 200a230c6..9ee6c7836 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -54,7 +54,7 @@ def preprocess(self, structure: Structure) -> ToolTask: def default_generate_system_template(self, _: PromptTask) -> str: return J2("tasks/tool_task/system.j2").render( - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), action_schema=utils.minify_json(json.dumps(self.tool.schema())), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, diff --git a/griptape/tools/prompt_summary/tool.py b/griptape/tools/prompt_summary/tool.py index e12b19547..517507380 100644 --- a/griptape/tools/prompt_summary/tool.py +++ b/griptape/tools/prompt_summary/tool.py @@ -52,4 +52,4 @@ def summarize(self, params: dict) -> BaseArtifact: else: return ErrorArtifact("memory not found") - return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.rulesets) + return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.all_rulesets) diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 0274e7940..de713cb84 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -32,7 +32,7 @@ class QueryTool(BaseTool, RuleMixin): lambda self: RagEngine( response_stage=ResponseRagStage( response_modules=[ - PromptResponseRagModule(prompt_driver=self.prompt_driver, rulesets=self.rulesets) + PromptResponseRagModule(prompt_driver=self.prompt_driver, rulesets=self.all_rulesets) ], ), ), diff --git a/tests/unit/mixins/test_rule_mixin.py b/tests/unit/mixins/test_rule_mixin.py index 19d3cbeac..c181dc315 100644 --- a/tests/unit/mixins/test_rule_mixin.py +++ b/tests/unit/mixins/test_rule_mixin.py @@ -15,21 +15,22 @@ def test_rules(self): assert mixin.rules == [rule] def test_rulesets(self): - ruleset = Ruleset("foo", [Rule("bar")]) - mixin = RuleMixin(rulesets=[ruleset]) + ruleset1 = Ruleset("foo", [Rule("bar")]) + mixin = RuleMixin(rulesets=[ruleset1]) - assert mixin.rulesets == [ruleset] - mixin.rulesets.append(Ruleset("baz", [Rule("qux")])) - assert mixin.rulesets == [ruleset] + assert mixin.rulesets == [ruleset1] + ruleset2 = Ruleset("baz", [Rule("qux")]) + mixin.rulesets.append(ruleset2) + assert mixin.rulesets == [ruleset1, ruleset2] def test_rules_and_rulesets(self): mixin = RuleMixin(rules=[Rule("foo")], rulesets=[Ruleset("bar", [Rule("baz")])]) assert mixin.rules == [Rule("foo")] - assert mixin.rulesets[0].name == "bar" - assert mixin.rulesets[0].rules == [Rule("baz")] - assert mixin.rulesets[1].name == "Default Ruleset" - assert mixin.rulesets[1].rules == [Rule("foo")] + assert mixin.all_rulesets[0].name == "bar" + assert mixin.all_rulesets[0].rules == [Rule("baz")] + assert mixin.all_rulesets[1].name == "Default Ruleset" + assert mixin.all_rulesets[1].rules == [Rule("foo")] def test_inherits_structure_rulesets(self): # Tests that a task using the mixin inherits rulesets from its structure. @@ -40,7 +41,7 @@ def test_inherits_structure_rulesets(self): task = PromptTask(rulesets=[ruleset2]) agent.add_task(task) - assert task.rulesets == [ruleset1, ruleset2] + assert task.all_rulesets == [ruleset1, ruleset2] def test_to_dict(self): mixin = RuleMixin( @@ -60,31 +61,25 @@ def test_to_dict(self): ) assert mixin.to_dict() == { + "rules": [ + {"type": "Rule", "value": "foo"}, + { + "type": "JsonSchemaRule", + "value": { + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + ], "rulesets": [ { - "id": mixin.rulesets[0].id, + "id": mixin.all_rulesets[0].id, "meta": {}, "name": "bar", "rules": [{"type": "Rule", "value": "baz"}], "type": "Ruleset", }, - { - "name": "Default Ruleset", - "id": mixin.rulesets[1].id, - "meta": {}, - "rules": [ - {"type": "Rule", "value": "foo"}, - { - "type": "JsonSchemaRule", - "value": { - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - ], - "type": "Ruleset", - }, ], "type": "RuleMixin", } diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 442f654d5..523e8596a 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -21,8 +21,8 @@ def test_init(self): assert agent.prompt_driver is driver assert isinstance(agent.task, PromptTask) assert isinstance(agent.task, PromptTask) - assert agent.rulesets[0].name == "TestRuleset" - assert agent.rulesets[0].rules[0].value == "test" + assert agent.all_rulesets[0].name == "TestRuleset" + assert agent.all_rulesets[0].rules[0].value == "test" assert isinstance(agent.conversation_memory, ConversationMemory) assert isinstance(Agent(tools=[MockTool()]).task, PromptTask) @@ -32,9 +32,9 @@ def test_rulesets(self): agent.add_task(PromptTask(rulesets=[Ruleset("Bar", [Rule("bar test")])])) assert isinstance(agent.task, PromptTask) - assert len(agent.task.rulesets) == 2 - assert agent.task.rulesets[0].name == "Foo" - assert agent.task.rulesets[1].name == "Bar" + assert len(agent.task.all_rulesets) == 2 + assert agent.task.all_rulesets[0].name == "Foo" + assert agent.task.all_rulesets[1].name == "Bar" def test_rules(self): agent = Agent(rules=[Rule("foo test")]) @@ -42,21 +42,21 @@ def test_rules(self): agent.add_task(PromptTask(rules=[Rule("bar test")])) assert isinstance(agent.task, PromptTask) - assert len(agent.task.rulesets) == 1 - assert agent.task.rulesets[0].name == "Default Ruleset" - assert len(agent.task.rulesets[0].rules) == 2 - assert agent.task.rulesets[0].rules[0].value == "foo test" - assert agent.task.rulesets[0].rules[1].value == "bar test" + assert len(agent.task.all_rulesets) == 1 + assert agent.task.all_rulesets[0].name == "Default Ruleset" + assert len(agent.task.all_rulesets[0].rules) == 2 + assert agent.task.all_rulesets[0].rules[0].value == "foo test" + assert agent.task.all_rulesets[0].rules[1].value == "bar test" def test_rules_and_rulesets(self): agent = Agent(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) - assert len(agent.rulesets) == 2 + assert len(agent.all_rulesets) == 2 assert len(agent.rules) == 1 agent = Agent() agent.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) assert isinstance(agent.task, PromptTask) - assert len(agent.task.rulesets) == 2 + assert len(agent.task.all_rulesets) == 2 assert len(agent.task.rules) == 1 def test_with_task_memory(self): diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index d461b5bbf..c670538f0 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -32,8 +32,8 @@ def test_init(self): assert pipeline.input_task is None assert pipeline.output_task is None - assert pipeline.rulesets[0].name == "TestRuleset" - assert pipeline.rulesets[0].rules[0].value == "test" + assert pipeline.all_rulesets[0].name == "TestRuleset" + assert pipeline.all_rulesets[0].rules[0].value == "test" assert pipeline.conversation_memory is not None def test_rulesets(self): @@ -45,14 +45,14 @@ def test_rulesets(self): ) assert isinstance(pipeline.tasks[0], PromptTask) - assert len(pipeline.tasks[0].rulesets) == 2 - assert pipeline.tasks[0].rulesets[0].name == "Foo" - assert pipeline.tasks[0].rulesets[1].name == "Bar" + assert len(pipeline.tasks[0].all_rulesets) == 2 + assert pipeline.tasks[0].all_rulesets[0].name == "Foo" + assert pipeline.tasks[0].all_rulesets[1].name == "Bar" assert isinstance(pipeline.tasks[1], PromptTask) - assert len(pipeline.tasks[1].rulesets) == 2 - assert pipeline.tasks[1].rulesets[0].name == "Foo" - assert pipeline.tasks[1].rulesets[1].name == "Baz" + assert len(pipeline.tasks[1].all_rulesets) == 2 + assert pipeline.tasks[1].all_rulesets[0].name == "Foo" + assert pipeline.tasks[1].all_rulesets[1].name == "Baz" def test_rules(self): pipeline = Pipeline(rules=[Rule("foo test")]) @@ -60,23 +60,24 @@ def test_rules(self): pipeline.add_tasks(PromptTask(rules=[Rule("bar test")]), PromptTask(rules=[Rule("baz test")])) assert isinstance(pipeline.tasks[0], PromptTask) - assert len(pipeline.tasks[0].rulesets) == 1 - assert pipeline.tasks[0].rulesets[0].name == "Default Ruleset" - assert len(pipeline.tasks[0].rulesets[0].rules) == 2 + assert len(pipeline.tasks[0].all_rulesets) == 1 + assert pipeline.tasks[0].all_rulesets[0].name == "Default Ruleset" + assert len(pipeline.tasks[0].all_rulesets[0].rules) == 2 assert isinstance(pipeline.tasks[1], PromptTask) - assert pipeline.tasks[1].rulesets[0].name == "Default Ruleset" - assert len(pipeline.tasks[1].rulesets[0].rules) == 2 + assert pipeline.tasks[1].all_rulesets[0].name == "Default Ruleset" + assert len(pipeline.tasks[1].all_rulesets[0].rules) == 2 def test_rules_and_rulesets(self): pipeline = Pipeline(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) - assert len(pipeline.rulesets) == 2 + assert len(pipeline.all_rulesets) == 2 + assert len(pipeline.rulesets) == 1 assert len(pipeline.rules) == 1 pipeline = Pipeline() pipeline.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) assert isinstance(pipeline.tasks[0], PromptTask) - assert len(pipeline.tasks[0].rulesets) == 2 + assert len(pipeline.tasks[0].all_rulesets) == 2 assert len(pipeline.tasks[0].rules) == 1 def test_with_no_task_memory(self): diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index da277e81e..81e165595 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -74,6 +74,7 @@ def test_to_dict(self): "max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries, "context": agent.tasks[0].context, "rulesets": [], + "rules": [], "max_subtasks": 20, "tools": [], "prompt_driver": { @@ -88,6 +89,7 @@ def test_to_dict(self): } ], "rulesets": [], + "rules": [], "conversation_memory": { "type": agent.conversation_memory.type, "runs": agent.conversation_memory.runs, diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index a40b20b93..5835ca0a0 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -42,14 +42,14 @@ def test_rulesets(self): ) assert isinstance(workflow.tasks[0], PromptTask) - assert len(workflow.tasks[0].rulesets) == 2 - assert workflow.tasks[0].rulesets[0].name == "Foo" - assert workflow.tasks[0].rulesets[1].name == "Bar" + assert len(workflow.tasks[0].all_rulesets) == 2 + assert workflow.tasks[0].all_rulesets[0].name == "Foo" + assert workflow.tasks[0].all_rulesets[1].name == "Bar" assert isinstance(workflow.tasks[1], PromptTask) - assert len(workflow.tasks[1].rulesets) == 2 - assert workflow.tasks[1].rulesets[0].name == "Foo" - assert workflow.tasks[1].rulesets[1].name == "Baz" + assert len(workflow.tasks[1].all_rulesets) == 2 + assert workflow.tasks[1].all_rulesets[0].name == "Foo" + assert workflow.tasks[1].all_rulesets[1].name == "Baz" def test_rules(self): workflow = Workflow(rules=[Rule("foo test")]) @@ -57,24 +57,24 @@ def test_rules(self): workflow.add_tasks(PromptTask(rules=[Rule("bar test")]), PromptTask(rules=[Rule("baz test")])) assert isinstance(workflow.tasks[0], PromptTask) - assert len(workflow.tasks[0].rulesets) == 1 - assert workflow.tasks[0].rulesets[0].name == "Default Ruleset" - assert len(workflow.tasks[0].rulesets[0].rules) == 2 + assert len(workflow.tasks[0].all_rulesets) == 1 + assert workflow.tasks[0].all_rulesets[0].name == "Default Ruleset" + assert len(workflow.tasks[0].all_rulesets[0].rules) == 2 assert isinstance(workflow.tasks[1], PromptTask) - assert len(workflow.tasks[1].rulesets) == 1 - assert workflow.tasks[1].rulesets[0].name == "Default Ruleset" - assert len(workflow.tasks[1].rulesets[0].rules) == 2 + assert len(workflow.tasks[1].all_rulesets) == 1 + assert workflow.tasks[1].all_rulesets[0].name == "Default Ruleset" + assert len(workflow.tasks[1].all_rulesets[0].rules) == 2 def test_rules_and_rulesets(self): workflow = Workflow(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) - assert len(workflow.rulesets) == 2 + assert len(workflow.all_rulesets) == 2 assert len(workflow.rules) == 1 workflow = Workflow() workflow.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) assert isinstance(workflow.tasks[0], PromptTask) - assert len(workflow.tasks[0].rulesets) == 2 + assert len(workflow.tasks[0].all_rulesets) == 2 assert len(workflow.tasks[0].rules) == 1 def test_with_no_task_memory(self): diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 1be3904c5..e4adebbf0 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -54,11 +54,11 @@ def test_rulesets(self): rulesets=[Ruleset("Foo", [Rule("foo test")]), Ruleset("Bar", [Rule("bar test")])] ) - assert len(prompt_task.rulesets) == 2 - assert prompt_task.rulesets[0].name == "Foo" - assert prompt_task.rulesets[1].name == "Bar" + assert len(prompt_task.all_rulesets) == 2 + assert prompt_task.all_rulesets[0].name == "Foo" + assert prompt_task.all_rulesets[1].name == "Bar" def test_rules(self): prompt_task = MockTextInputTask(rules=[Rule("foo test"), Rule("bar test")]) - assert prompt_task.rulesets[0].name == "Default Ruleset" + assert prompt_task.all_rulesets[0].name == "Default Ruleset" diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index d146d2249..164444148 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -194,15 +194,15 @@ def test_rulesets(self): pipeline.add_task(task) - assert len(task.rulesets) == 3 - assert task.rulesets[0].name == "Pipeline Ruleset" - assert task.rulesets[1].name == "Task Ruleset" - assert task.rulesets[2].name == "Default Ruleset" - - assert len(task.rulesets[0].rules) == 0 - assert len(task.rulesets[1].rules) == 0 - assert task.rulesets[2].rules[0].value == "Pipeline Rule" - assert task.rulesets[2].rules[1].value == "Task Rule" + assert len(task.all_rulesets) == 3 + assert task.all_rulesets[0].name == "Pipeline Ruleset" + assert task.all_rulesets[1].name == "Task Ruleset" + assert task.all_rulesets[2].name == "Default Ruleset" + + assert len(task.all_rulesets[0].rules) == 0 + assert len(task.all_rulesets[1].rules) == 0 + assert task.all_rulesets[2].rules[0].value == "Pipeline Rule" + assert task.all_rulesets[2].rules[1].value == "Task Rule" def test_conversation_memory(self): conversation_memory = ConversationMemory() diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 5c7f6b394..ea86d56c8 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -251,6 +251,7 @@ def test_to_dict(self): "max_meta_memory_entries": task.max_meta_memory_entries, "context": task.context, "rulesets": [], + "rules": [], "prompt_driver": { "extra_params": {}, "max_tokens": None, diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index a5e95f4d1..e781a7710 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -392,6 +392,7 @@ def test_to_dict(self): "max_meta_memory_entries": 20, "context": {}, "rulesets": [], + "rules": [], "prompt_driver": { "extra_params": {}, "max_tokens": None, diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index a34871013..0409ee978 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -248,7 +248,7 @@ def verify_structure_output(self, structure) -> dict: task_names = [task.__class__.__name__ for task in structure.tasks] prompt = structure.input_task.input.to_text() actual = structure.output.to_text() - rules = [rule.value for ruleset in structure.input_task.rulesets for rule in ruleset.rules] + rules = [rule.value for ruleset in structure.input_task.all_rulesets for rule in ruleset.rules] agent = Agent( rulesets=[