Skip to content

Commit

Permalink
feat: Support screen_hint in user_management.get_authorization_url (#396
Browse files Browse the repository at this point in the history
)

* feat: Support screen_hint in user_management.get_authorization_url

* Format

* Format

* Format

* Format
  • Loading branch information
faroceann authored Jan 6, 2025
1 parent dadb944 commit 4b3c7bf
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
30 changes: 27 additions & 3 deletions tests/test_user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ def mock_invitations_multiple_pages(self):
class TestUserManagementBase(UserManagementFixtures):
@pytest.fixture(autouse=True)
def setup(self, sync_client_configuration_and_http_client_for_test):
client_configuration, http_client = (
sync_client_configuration_and_http_client_for_test
)
(
client_configuration,
http_client,
) = sync_client_configuration_and_http_client_for_test
self.http_client = http_client
self.user_management = UserManagement(
http_client=self.http_client, client_configuration=client_configuration
Expand Down Expand Up @@ -320,6 +321,29 @@ def test_authorization_url_has_expected_query_params_with_code_challenge(self):
"response_type": RESPONSE_TYPE_CODE,
}

def test_authorization_url_has_expected_query_params_with_screen_hint(self):
connection_id = "connection_123"
redirect_uri = "https://localhost/auth/callback"
screen_hint = "sign-up"

authorization_url = self.user_management.get_authorization_url(
connection_id=connection_id,
screen_hint=screen_hint,
redirect_uri=redirect_uri,
provider="authkit",
)

parsed_url = urlparse(authorization_url)
assert parsed_url.path == "/user_management/authorize"
assert dict(parse_qsl(str(parsed_url.query))) == {
"screen_hint": screen_hint,
"client_id": self.http_client.client_id,
"redirect_uri": redirect_uri,
"connection_id": connection_id,
"response_type": RESPONSE_TYPE_CODE,
"provider": "authkit",
}

def test_get_jwks_url(self):
expected = "%ssso/jwks/%s" % (
self.http_client.base_url,
Expand Down
3 changes: 3 additions & 0 deletions workos/types/user_management/screen_hint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Literal

ScreenHintType = Literal["sign-up", "sign-in"]
6 changes: 6 additions & 0 deletions workos/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
UsersListFilters,
)
from workos.types.user_management.password_hash_type import PasswordHashType
from workos.types.user_management.screen_hint import ScreenHintType
from workos.types.user_management.session import SessionConfig
from workos.types.user_management.user_management_provider_type import (
UserManagementProviderType,
Expand Down Expand Up @@ -342,6 +343,7 @@ def get_authorization_url(
organization_id: Optional[str] = None,
code_challenge: Optional[str] = None,
prompt: Optional[str] = None,
screen_hint: Optional[ScreenHintType] = None,
) -> str:
"""Generate an OAuth 2.0 authorization URL.
Expand Down Expand Up @@ -369,6 +371,7 @@ def get_authorization_url(
prompt (str): Used to specify whether the upstream provider should prompt the user for credentials or other
consent. Valid values depend on the provider. Currently only applies to provider values of 'GoogleOAuth',
'MicrosoftOAuth', or 'GitHubOAuth'. (Optional)
screen_hint (ScreenHintType): Specify which AuthKit screen users should land on upon redirection (Only applicable when provider is 'authkit').
Returns:
str: URL to redirect a User to to begin the OAuth workflow with WorkOS
Expand All @@ -385,6 +388,7 @@ def get_authorization_url(
)

if connection_id is not None:

params["connection_id"] = connection_id
if organization_id is not None:
params["organization_id"] = organization_id
Expand All @@ -401,6 +405,8 @@ def get_authorization_url(
params["code_challenge_method"] = "S256"
if prompt is not None:
params["prompt"] = prompt
if screen_hint is not None:
params["screen_hint"] = screen_hint

return RequestHelper.build_url_with_query_params(
base_url=self._client_configuration.base_url,
Expand Down

0 comments on commit 4b3c7bf

Please sign in to comment.