Skip to content

Commit

Permalink
Add the correct path when it's a package
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fraenkel committed May 3, 2023
1 parent 0fa82cf commit 0f7016e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 0 deletions.
2 changes: 2 additions & 0 deletions integration/_support/package/tasks/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from invoke import task
from . import pytest as pt
from pytest import Testdir


@task
Expand Down
Empty file.
2 changes: 2 additions & 0 deletions invoke/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def load(self, name: Optional[str] = None) -> Tuple[ModuleType, str]:
# being imported is trying to load local-to-it names.
if os.path.isfile(spec.origin):
path = os.path.dirname(spec.origin)
if spec.origin.endswith("__init__.py"):
path = os.path.dirname(path)
if path not in sys.path:
sys.path.insert(0, path)
# Actual import
Expand Down
6 changes: 6 additions & 0 deletions tests/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def adds_module_parent_dir_to_sys_path(self):
# Crummy doesn't-explode test.
_BasicLoader().load("namespacing")

def adds_package_dir_to_sys_path(self):
config = Config({"tasks": {"collection_name": "module"}})
_BasicLoader(config).load("package")
package = Path(support) / "package"
assert str(package) not in sys.path

def doesnt_duplicate_parent_dir_addition(self):
_BasicLoader().load("namespacing")
_BasicLoader().load("namespacing")
Expand Down

0 comments on commit 0f7016e

Please sign in to comment.