Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support screen_hint in user_management.get_authorization_url #396

Merged
merged 5 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions tests/test_user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ def mock_invitations_multiple_pages(self):
class TestUserManagementBase(UserManagementFixtures):
@pytest.fixture(autouse=True)
def setup(self, sync_client_configuration_and_http_client_for_test):
client_configuration, http_client = (
sync_client_configuration_and_http_client_for_test
)
(
client_configuration,
http_client,
) = sync_client_configuration_and_http_client_for_test
self.http_client = http_client
self.user_management = UserManagement(
http_client=self.http_client, client_configuration=client_configuration
Expand Down Expand Up @@ -320,6 +321,29 @@ def test_authorization_url_has_expected_query_params_with_code_challenge(self):
"response_type": RESPONSE_TYPE_CODE,
}

def test_authorization_url_has_expected_query_params_with_screen_hint(self):
connection_id = "connection_123"
redirect_uri = "https://localhost/auth/callback"
screen_hint = "sign-up"

authorization_url = self.user_management.get_authorization_url(
connection_id=connection_id,
screen_hint=screen_hint,
redirect_uri=redirect_uri,
provider="authkit",
)

parsed_url = urlparse(authorization_url)
assert parsed_url.path == "/user_management/authorize"
assert dict(parse_qsl(str(parsed_url.query))) == {
"screen_hint": screen_hint,
"client_id": self.http_client.client_id,
"redirect_uri": redirect_uri,
"connection_id": connection_id,
"response_type": RESPONSE_TYPE_CODE,
"provider": "authkit",
}

def test_get_jwks_url(self):
expected = "%ssso/jwks/%s" % (
self.http_client.base_url,
Expand Down
3 changes: 3 additions & 0 deletions workos/types/user_management/screen_hint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Literal

ScreenHintType = Literal["sign-up", "sign-in"]
6 changes: 6 additions & 0 deletions workos/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
UsersListFilters,
)
from workos.types.user_management.password_hash_type import PasswordHashType
from workos.types.user_management.screen_hint import ScreenHintType
from workos.types.user_management.session import SessionConfig
from workos.types.user_management.user_management_provider_type import (
UserManagementProviderType,
Expand Down Expand Up @@ -342,6 +343,7 @@ def get_authorization_url(
organization_id: Optional[str] = None,
code_challenge: Optional[str] = None,
prompt: Optional[str] = None,
screen_hint: Optional[ScreenHintType] = None,
) -> str:
"""Generate an OAuth 2.0 authorization URL.

Expand Down Expand Up @@ -369,6 +371,7 @@ def get_authorization_url(
prompt (str): Used to specify whether the upstream provider should prompt the user for credentials or other
consent. Valid values depend on the provider. Currently only applies to provider values of 'GoogleOAuth',
'MicrosoftOAuth', or 'GitHubOAuth'. (Optional)
screen_hint (ScreenHintType): Specify which AuthKit screen users should land on upon redirection (Only applicable when provider is 'authkit').

Returns:
str: URL to redirect a User to to begin the OAuth workflow with WorkOS
Expand All @@ -385,6 +388,7 @@ def get_authorization_url(
)

if connection_id is not None:

params["connection_id"] = connection_id
if organization_id is not None:
params["organization_id"] = organization_id
Expand All @@ -401,6 +405,8 @@ def get_authorization_url(
params["code_challenge_method"] = "S256"
if prompt is not None:
params["prompt"] = prompt
if screen_hint is not None:
params["screen_hint"] = screen_hint

return RequestHelper.build_url_with_query_params(
base_url=self._client_configuration.base_url,
Expand Down
Loading