-
Notifications
You must be signed in to change notification settings - Fork 562
Add gradio-tools plugin #14
base: master
Are you sure you want to change the base?
Changes from all commits
9038116
a8833ac
25cc465
def0335
1038ab1
1fae35e
0645a8e
271503a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# freddyaboulton/autogpt-gradio-tools 🤝 | ||
|
||
A plugin giving AutoGPT access to [Gradio](https://github.com/gradio-app/gradio) spaces running on | ||
the [huggingface hub](https://huggingface.co/spaces) and elsewhere! | ||
|
||
Integration powered by [gradio-tools](https://github.com/freddyaboulton/gradio-tools) | ||
|
||
gradio-tools comes with a set of pre-built tools but it is easy to add new tools. | ||
|
||
All contributions to `gradio-tools` and this plugin are welcome! | ||
|
||
## Features | ||
|
||
Each tool specified via the env file will add a command that gives AutoGPT | ||
the ability to call that gradio app programmatically and get its prediction. | ||
|
||
For example, an LLM could use a Gradio tool to transcribe a voice recording it finds online and then summarize it for you. Or it could use a different Gradio tool to apply OCR to a document on your Google Drive and then answer questions about it. | ||
|
||
## Installation | ||
|
||
1. Download this repository, and save it as `autogpt-gradio-tools.zip` | ||
2. Place the `.zip` file in the plugins directory of your AutoGPT install | ||
3. Add your gradio-tools settings to the `.env` file within AutoGPT: | ||
|
||
``` | ||
################################################################################ | ||
### GRADIO-TOOLS | ||
################################################################################ | ||
|
||
# AUTOGPT_GRADIO_TOOLS is a comma-separated list of tool names to enable | ||
|
||
AUTOGPT_GRADIO_TOOLS=StableDiffusion,ImageCaptioner | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. List all tools that function here |
||
GRADIO_TOOLS_HF_TOKEN=<Optional hs token to clone spaces and avoid rate limits> | ||
``` | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
"""Integration with Gradio Spaces On HuggingFace via gradio_tools.""" | ||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar | ||
from dotenv import load_dotenv | ||
from auto_gpt_plugin_template import AutoGPTPluginTemplate | ||
from pathlib import Path | ||
import os | ||
from .tools import (AutoGPTClipInterrogatorTool, | ||
AutoGPTStableDiffusion, | ||
AutoGPTWhisperTool, | ||
AutoGPTTextToVideoTool, | ||
AutoGPTCaptioner, | ||
AutoGPTPromptGeneratorTool, | ||
AutoGPTImageToMusicTool, | ||
AutoGPTDocumentAnsweringTool) | ||
from gradio_tools import GradioTool | ||
|
||
# Placeholder for AutoGPT's PromptGenerator type, which is not importable here.
PromptGenerator = TypeVar("PromptGenerator")

# Load the user's .env (if present) so GRADIO_TOOLS_* settings are visible
# before the module-level TOOLS list below is built at import time.
# NOTE(review): upstream reviewers suggest this manual load shouldn't be
# needed at all once the plugin is loaded through AutoGPT — confirm.
_env_file = Path(os.getcwd()) / ".env"
if _env_file.exists():
    # load_dotenv accepts the path directly; no need to open a stream by hand.
    load_dotenv(dotenv_path=str(_env_file))
|
||
Comment on lines
+19
to
+22
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This shouldn't be needed |
||
|
||
# Optional HuggingFace token, shared by every tool to clone spaces and
# avoid rate limits; None when GRADIO_TOOLS_HF_TOKEN is unset.
_HF_TOKEN = os.getenv("GRADIO_TOOLS_HF_TOKEN")

# Every tool this plugin knows how to construct, in the original order.
# NOTE(review): this is built at import time as a module-level global;
# reviewers asked for it to be initialized inside the plugin class and only
# when the plugin is actually loaded — confirm before merging.
TOOLS = [
    tool_cls(hf_token=_HF_TOKEN)
    for tool_cls in (
        AutoGPTStableDiffusion,
        AutoGPTCaptioner,
        AutoGPTWhisperTool,
        AutoGPTTextToVideoTool,
        AutoGPTPromptGeneratorTool,
        AutoGPTDocumentAnsweringTool,
        AutoGPTImageToMusicTool,
        AutoGPTClipInterrogatorTool,
    )
]
Comment on lines
+24
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only run this if your plugin is loaded. Take a look at the email plugin to see how to do this
Comment on lines
+24
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't use this like a global variable, initialize it within the init and scope it to the plugin's class instance |
||
|
||
def get_tool(tool: str) -> GradioTool:
    """Return the tool from TOOLS whose ``name`` equals *tool*.

    Args:
        tool: The tool name, e.g. ``"StableDiffusion"``.

    Returns:
        The matching GradioTool instance.

    Raises:
        ValueError: If no tool with that name exists.  (The previous
            implementation leaked a bare ``StopIteration`` from ``next()``,
            which crashed plugin loading with an opaque error.)
    """
    found = next((t for t in TOOLS if t.name == tool), None)
    if found is None:
        raise ValueError(f"Unknown gradio tool: {tool!r}")
    return found
|
||
|
||
# Typed shape of a single chat message exchanged with the model:
# a plain dict carrying a "role" and its "content" text.
Message = TypedDict("Message", {"role": str, "content": str})
|
||
|
||
class AutoGPTGradioTools(AutoGPTPluginTemplate):
    """AutoGPT plugin exposing Gradio apps (via gradio_tools) as commands.

    The tools named in the ``AUTOGPT_GRADIO_TOOLS`` environment variable
    (comma separated) are registered as AutoGPT commands in ``post_prompt``.
    Only ``post_prompt`` is active; every other hook reports it cannot handle
    the event.  (The previous docstring wrongly described this class as
    "Twitter API integrations using Tweepy" — a copy/paste leftover.)
    """

    def __init__(self):
        super().__init__()
        self._name = "autogpt-gradio-tools"
        self._version = "0.1.0"
        self._description = "Calling Gradio Apps ."
        # Comma-separated tool names from the env; empty entries are dropped.
        requested = [
            t for t in os.getenv("AUTOGPT_GRADIO_TOOLS", "").split(",") if t != ""
        ]
        # Skip names that match no known tool instead of crashing plugin
        # load (get_tool raises on an unknown name; previously that
        # exception propagated and aborted AutoGPT startup).
        self.tools = []
        for tool_name in requested:
            try:
                self.tools.append(get_tool(tool_name))
            except (StopIteration, ValueError):
                print(f"autogpt-gradio-tools: unknown tool {tool_name!r} - skipping")

    def can_handle_on_response(self) -> bool:
        """This method is called to check that the plugin can
        handle the on_response method.
        Returns:
            bool: True if the plugin can handle the on_response method."""
        return False

    def on_response(self, response: str, *args, **kwargs) -> str:
        """This method is called when a response is received from the model."""
        pass

    def can_handle_post_prompt(self) -> bool:
        """This method is called to check that the plugin can
        handle the post_prompt method.
        Returns:
            bool: True if the plugin can handle the post_prompt method."""
        # The only hook this plugin implements: it adds its tool commands.
        return True

    def can_handle_on_planning(self) -> bool:
        """This method is called to check that the plugin can
        handle the on_planning method.
        Returns:
            bool: True if the plugin can handle the on_planning method."""
        return False

    def on_planning(
        self, prompt: PromptGenerator, messages: List[str]
    ) -> Optional[str]:
        """This method is called before the planning chat completion is done.
        Args:
            prompt (PromptGenerator): The prompt generator.
            messages (List[str]): The list of messages.
        """
        pass

    def can_handle_post_planning(self) -> bool:
        """This method is called to check that the plugin can
        handle the post_planning method.
        Returns:
            bool: True if the plugin can handle the post_planning method."""
        return False

    def post_planning(self, response: str) -> str:
        """This method is called after the planning chat completion is done.
        Args:
            response (str): The response.
        Returns:
            str: The resulting response.
        """
        pass

    def can_handle_pre_instruction(self) -> bool:
        """This method is called to check that the plugin can
        handle the pre_instruction method.
        Returns:
            bool: True if the plugin can handle the pre_instruction method."""
        return False

    def pre_instruction(self, messages: List[str]) -> List[str]:
        """This method is called before the instruction chat is done.
        Args:
            messages (List[str]): The list of context messages.
        Returns:
            List[str]: The resulting list of messages.
        """
        pass

    def can_handle_on_instruction(self) -> bool:
        """This method is called to check that the plugin can
        handle the on_instruction method.
        Returns:
            bool: True if the plugin can handle the on_instruction method."""
        return False

    def on_instruction(self, messages: List[str]) -> Optional[str]:
        """This method is called when the instruction chat is done.
        Args:
            messages (List[str]): The list of context messages.
        Returns:
            Optional[str]: The resulting message.
        """
        pass

    def can_handle_post_instruction(self) -> bool:
        """This method is called to check that the plugin can
        handle the post_instruction method.
        Returns:
            bool: True if the plugin can handle the post_instruction method."""
        return False

    def post_instruction(self, response: str) -> str:
        """This method is called after the instruction chat is done.
        Args:
            response (str): The response.
        Returns:
            str: The resulting response.
        """
        pass

    def can_handle_pre_command(self) -> bool:
        """This method is called to check that the plugin can
        handle the pre_command method.
        Returns:
            bool: True if the plugin can handle the pre_command method."""
        return False

    def pre_command(
        self, command_name: str, arguments: Dict[str, Any]
    ) -> Tuple[str, Dict[str, Any]]:
        """This method is called before the command is executed.
        Args:
            command_name (str): The command name.
            arguments (Dict[str, Any]): The arguments.
        Returns:
            Tuple[str, Dict[str, Any]]: The command name and the arguments.
        """
        pass

    def can_handle_post_command(self) -> bool:
        """This method is called to check that the plugin can
        handle the post_command method.
        Returns:
            bool: True if the plugin can handle the post_command method."""
        return False

    def post_command(self, command_name: str, response: str) -> str:
        """This method is called after the command is executed.
        Args:
            command_name (str): The command name.
            response (str): The response.
        Returns:
            str: The resulting response.
        """
        pass

    def can_handle_chat_completion(
        self,
        messages: list[Dict[Any, Any]],
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> bool:
        """This method is called to check that the plugin can
        handle the chat_completion method.
        Args:
            messages (Dict[Any, Any]): The messages.
            model (str): The model name.
            temperature (float): The temperature.
            max_tokens (int): The max tokens.
        Returns:
            bool: True if the plugin can handle the chat_completion method."""
        return False

    def handle_chat_completion(
        self,
        messages: list[Dict[Any, Any]],
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> str:
        """This method is called when the chat completion is done.
        Args:
            messages (Dict[Any, Any]): The messages.
            model (str): The model name.
            temperature (float): The temperature.
            max_tokens (int): The max tokens.
        Returns:
            str: The resulting response.
        """
        # Never invoked: can_handle_chat_completion always returns False.
        return None

    def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator:
        """This method is called just after the generate_prompt is called,
        but actually before the prompt is generated.

        Registers one AutoGPT command per configured gradio tool, so the
        model can invoke the underlying Gradio app via ``tool.run``.

        Args:
            prompt (PromptGenerator): The prompt generator.
        Returns:
            PromptGenerator: The prompt generator.
        """
        if self.tools:
            for tool in self.tools:
                prompt.add_command(tool.description, tool.name.lower(), tool.args, tool.run)

        return prompt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
from unittest.mock import patch, MagicMock | ||
import unittest | ||
import gradio_tools | ||
from .tools import (AutoGPTClipInterrogatorTool, | ||
AutoGPTStableDiffusion, | ||
AutoGPTWhisperTool, | ||
AutoGPTTextToVideoTool, | ||
AutoGPTCaptioner, | ||
AutoGPTPromptGeneratorTool, | ||
AutoGPTImageToMusicTool, | ||
AutoGPTDocumentAnsweringTool) | ||
|
||
|
||
class TestGradioTools(unittest.TestCase):
    """Unit tests for the gradio-tools AutoGPT plugin."""

    @patch.dict(
        os.environ,
        {
            "AUTOGPT_GRADIO_TOOLS": "WhisperAudioTranscription,TextToVideo",
        },
    )
    def test_right_tools_loaded(self):
        # Import inside the test so the patched env is seen at plugin init.
        from . import AutoGPTGradioTools

        plugin = AutoGPTGradioTools()
        self.assertEqual(plugin.tools[0].name, "WhisperAudioTranscription")
        self.assertEqual(plugin.tools[1].name, "TextToVideo")

    @patch.dict(
        os.environ,
        {
            "AUTOGPT_GRADIO_TOOLS": "WhisperAudioTranscription,TextToVideo",
        },
    )
    def test_commands_added_to_prompt(self):
        from . import AutoGPTGradioTools

        fake_prompt = MagicMock()
        plugin = AutoGPTGradioTools()
        plugin.post_prompt(fake_prompt)
        # One add_command call per configured tool.
        self.assertEqual(fake_prompt.add_command.call_count, 2)

    def test_tools_configured_correctly(self):
        # Each plugin wrapper must mirror its gradio_tools counterpart.
        pairs = [
            (AutoGPTClipInterrogatorTool(), gradio_tools.ClipInterrogatorTool()),
            (AutoGPTStableDiffusion(), gradio_tools.StableDiffusionTool()),
            (AutoGPTWhisperTool(), gradio_tools.WhisperAudioTranscriptionTool()),
            (AutoGPTTextToVideoTool(), gradio_tools.TextToVideoTool()),
            (AutoGPTCaptioner(), gradio_tools.ImageCaptioningTool()),
            (AutoGPTPromptGeneratorTool(), gradio_tools.StableDiffusionPromptGeneratorTool()),
            (AutoGPTImageToMusicTool(), gradio_tools.ImageToMusicTool()),
            (AutoGPTDocumentAnsweringTool(), gradio_tools.DocQueryDocumentAnsweringTool()),
        ]
        for wrapper, reference in pairs:
            self.assertEqual(wrapper.name, reference.name)
            self.assertEqual(wrapper.description, reference.description)
            self.assertEqual(wrapper.src, reference.src)
            self.assertTrue(wrapper.args)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this section