Skip to content

Commit

Permalink
Migrate Azure MSI token source to httpclient and add 90% test coverage
Browse files Browse the repository at this point in the history
This PR improves the stability for Azure MSI authentication by adopting the httpclient transport.

add 100% coverage for MSI & metadata-service

..

..

..
  • Loading branch information
nfx committed Dec 1, 2023
1 parent e86cbfd commit 3e96fd4
Show file tree
Hide file tree
Showing 16 changed files with 455 additions and 171 deletions.
7 changes: 6 additions & 1 deletion apierr/unwrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ func (e *wrapError) Unwrap() error {
return e.wrap
}

func ByStatusCode(statusCode int) (error, bool) {
err, ok := statusCodeMapping[statusCode]
return err, ok
}

// Unwrap error for easier client code checking
//
// See https://pkg.go.dev/errors#example-Unwrap
Expand All @@ -28,7 +33,7 @@ func (apiError *APIError) Unwrap() error {
if ok {
return byErrorCode
}
byStatusCode, ok := statusCodeMapping[apiError.StatusCode]
byStatusCode, ok := ByStatusCode(apiError.StatusCode)
if ok {
return byStatusCode
}
Expand Down
2 changes: 0 additions & 2 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"golang.org/x/oauth2"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand Down Expand Up @@ -73,7 +72,6 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
}
return nil, err
}
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
Expand Down
2 changes: 0 additions & 2 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand Down Expand Up @@ -43,7 +42,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if !cfg.IsAzure() {
return nil, nil
}
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
Expand Down
93 changes: 39 additions & 54 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"

Expand All @@ -13,6 +13,9 @@ import (
"golang.org/x/oauth2"
)

var errInvalidToken = errors.New("invalid token")
var errInvalidTokenExpiry = errors.New("invalid token expiry")

// well-known URL for Azure Instance Metadata Service (IMDS)
// https://learn.microsoft.com/en-us/azure-stack/user/instance-metadata-service
var instanceMetadataPrefix = "http://169.254.169.254/metadata"
Expand All @@ -32,94 +35,76 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, nil
}
env := cfg.Environment()
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
if !cfg.IsAccountClient() {
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
}
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
inner := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: env.AzureApplicationID,
clientId: cfg.AzureClientID,
})
management := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: env.AzureServiceManagementEndpoint(),
clientId: cfg.AzureClientID,
})
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.azureApplicationID))

Check failure on line 45 in config/auth_azure_msi.go

View workflow job for this annotation

GitHub Actions / tests (1.19)

env.azureApplicationID undefined (type DatabricksEnvironment has no field or method azureApplicationID, but does have AzureApplicationID)

Check failure on line 45 in config/auth_azure_msi.go

View workflow job for this annotation

GitHub Actions / tests (1.19)

env.azureApplicationID undefined (type DatabricksEnvironment has no field or method azureApplicationID, but does have AzureApplicationID)

Check failure on line 45 in config/auth_azure_msi.go

View workflow job for this annotation

GitHub Actions / tests (1.20)

env.azureApplicationID undefined (type DatabricksEnvironment has no field or method azureApplicationID, but does have AzureApplicationID)

Check failure on line 45 in config/auth_azure_msi.go

View workflow job for this annotation

GitHub Actions / tests (1.20)

env.azureApplicationID undefined (type DatabricksEnvironment has no field or method azureApplicationID, but does have AzureApplicationID)

Check failure on line 45 in config/auth_azure_msi.go

View workflow job for this annotation

GitHub Actions / tests (1.21)

env.azureApplicationID undefined (type DatabricksEnvironment has no field or method azureApplicationID, but does have AzureApplicationID)

Check failure on line 45 in config/auth_azure_msi.go

View workflow job for this annotation

GitHub Actions / tests (1.21)

env.azureApplicationID undefined (type DatabricksEnvironment has no field or method azureApplicationID, but does have AzureApplicationID)
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint()))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return azureMsiTokenSource{
resource: resource,
client: cfg.refreshClient,
clientId: cfg.AzureClientID,
resource: resource,
}
}

type azureMsiTokenSource struct {
client *httpclient.ApiClient
resource string
clientId string
}

func (s azureMsiTokenSource) Token() (*oauth2.Token, error) {
ctx, cancel := context.WithTimeout(context.Background(), azureMsiTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix), nil)
if err != nil {
return nil, fmt.Errorf("token request: %w", err)
query := map[string]string{
"api-version": "2018-02-01",
"resource": s.resource,
}
query := req.URL.Query()
query.Add("api-version", "2018-02-01")
query.Add("resource", s.resource)
if s.clientId != "" {
query.Add("client_id", s.clientId)
query["client_id"] = s.clientId
}
req.URL.RawQuery = query.Encode()
req.Header.Add("Metadata", "true")
return makeMsiRequest(req)
}

func makeMsiRequest(req *http.Request) (*oauth2.Token, error) {
res, err := http.DefaultClient.Do(req)
var inner msiToken
err := s.client.Do(ctx, http.MethodGet,
fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix),
httpclient.WithRequestHeader("Metadata", "true"),
httpclient.WithRequestData(query),
httpclient.WithResponseUnmarshal(&inner),
)
if err != nil {
return nil, fmt.Errorf("token response: %w", err)
}
defer res.Body.Close()
if res.StatusCode == http.StatusNotFound {
return nil, nil
}
raw, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("token read: %w", err)
}
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token error: %s", raw)
}
var token azureMsiToken
err = json.Unmarshal(raw, &token)
if err != nil {
return nil, fmt.Errorf("token parse: %w", err)
return nil, fmt.Errorf("token request: %w", err)
}
return inner.Token()
}

type msiToken struct {
TokenType string `json:"token_type"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresOn json.Number `json:"expires_on"`
}

func (token msiToken) Token() (*oauth2.Token, error) {
if token.AccessToken == "" {
return nil, fmt.Errorf("token parse: invalid token")
return nil, fmt.Errorf("token parse: %w", errInvalidToken)
}
epoch, err := token.ExpiresOn.Int64()
if err != nil {
return nil, fmt.Errorf("token expires on: %w", err)
// go 1.19 doesn't support multiple error unwraps
return nil, fmt.Errorf("%w: %s", errInvalidTokenExpiry, err)
}
return &oauth2.Token{
TokenType: token.TokenType,
AccessToken: token.AccessToken,
Expiry: time.Unix(epoch, 0),
TokenType: token.TokenType,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
Expiry: time.Unix(epoch, 0),
}, nil
}

type azureMsiToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresOn json.Number `json:"expires_on"`
}
133 changes: 133 additions & 0 deletions config/auth_azure_msi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package config

import (
"net/http"
"testing"
"time"

"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/databricks/databricks-sdk-go/logger"
"github.com/stretchr/testify/require"
)

func init() {
logger.DefaultLogger = &logger.SimpleLogger{
Level: logger.LevelDebug,
}
}

func someValidToken(bearer string) any {
return map[string]any{
"token_type": "Bearer",
"access_token": bearer,
"expires_on": time.Now().Add(5 * time.Minute).Unix(),
}
}

func authenticateRequest(cfg *Config) (*http.Request, error) {
cfg.ConfigFile = "/dev/null"
cfg.DebugHeaders = true
req, _ := http.NewRequest("GET", "http://localhost", nil)
err := cfg.Authenticate(req)
return req, err
}

func assertHeaders(t *testing.T, cfg *Config, expectedHeaders map[string]string) {
req, err := authenticateRequest(cfg)
require.NoError(t, err)
actualHeaders := map[string]string{}
for k := range req.Header {
actualHeaders[k] = req.Header.Get(k)
}
require.Equal(t, expectedHeaders, actualHeaders)
}

func TestMsiHappyFlow(t *testing.T) {
assertHeaders(t, &Config{
AzureUseMSI: true,
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
ExpectedHeaders: map[string]string{
"Metadata": "true",
},
Response: someValidToken("bcd"),
},
"GET /a/b/c?api-version=2018-04-01": {
Response: `{"properties": {
"workspaceUrl": "https://abc"
}}`,
},
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=2ff814a6-3304-4ab8-85cb-cd0e6f879c1d": {
ExpectedHeaders: map[string]string{
"Metadata": "true",
},
Response: someValidToken("cde"),
},
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.core.windows.net%2F": {
ExpectedHeaders: map[string]string{
"Metadata": "true",
},
Response: someValidToken("def"),
},
},
}, map[string]string{
"Authorization": "Bearer cde",
"X-Databricks-Azure-Sp-Management-Token": "def",
"X-Databricks-Azure-Workspace-Resource-Id": "/a/b/c",
})
}

func TestMsiFailsOnResolveWorkspace(t *testing.T) {
_, err := authenticateRequest(&Config{
AzureUseMSI: true,
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
Response: someValidToken("bcd"),
},
"GET /a/b/c?api-version=2018-04-01": {
Status: 404,
Response: azureResourceManagerErrorResponse{
Error: azureResourceManagerErrorError{
Message: "nope",
},
},
},
},
})
require.ErrorIs(t, err, apierr.ErrNotFound)
}

func TestMsiTokenNotFound(t *testing.T) {
_, err := authenticateRequest(&Config{
AzureUseMSI: true,
AzureClientID: "abc",
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&client_id=abc&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
Status: 404,
Response: `...`,
},
},
})
require.ErrorIs(t, err, apierr.ErrNotFound)
}

func TestMsiInvalidTokenExpiry(t *testing.T) {
_, err := authenticateRequest(&Config{
AzureUseMSI: true,
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
Response: map[string]any{
"token_type": "Bearer",
"access_token": "abc",
"expires_on": "12345678912345678901234567890123456789123456789",
},
},
},
})
require.ErrorIs(t, err, errInvalidTokenExpiry)
}
22 changes: 16 additions & 6 deletions config/auth_metadata_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package config

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
)
Expand All @@ -18,6 +20,9 @@ const MetadataServiceVersion = "1"
const MetadataServiceVersionHeader = "X-Databricks-Metadata-Version"
const MetadataServiceHostHeader = "X-Databricks-Host"

var errMetadataServiceMalformed = errors.New("invalid auth server URL")
var errMetadataServiceNotLocalhost = errors.New("only localhost URLs are allowed")

// Credentials provider that fetches a token from a locally running HTTP server
//
// The credentials provider will perform a GET request to the configured URL.
Expand Down Expand Up @@ -49,11 +54,12 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config)
}
parsedMetadataServiceURL, err := url.Parse(cfg.MetadataServiceURL)
if err != nil {
return nil, fmt.Errorf("invalid auth server URL: %w", err)
// go 1.19 doesn't allow multiple error unwraping
return nil, fmt.Errorf("%w: %s", errMetadataServiceMalformed, err)
}
// only allow localhost URLs
if parsedMetadataServiceURL.Hostname() != "localhost" && parsedMetadataServiceURL.Hostname() != "127.0.0.1" {
return nil, fmt.Errorf("invalid auth server URL: %s", cfg.MetadataServiceURL)
return nil, fmt.Errorf("%w: %s", errMetadataServiceNotLocalhost, cfg.MetadataServiceURL)
}
ms := metadataService{
metadataServiceURL: parsedMetadataServiceURL,
Expand All @@ -78,13 +84,17 @@ type metadataService struct {
func (s metadataService) Get() (*oauth2.Token, error) {
ctx, cancel := context.WithTimeout(context.Background(), metadataServiceTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.metadataServiceURL.String(), nil)
var inner msiToken
err := s.config.refreshClient.Do(ctx, http.MethodGet,
s.metadataServiceURL.String(),
httpclient.WithRequestHeader(MetadataServiceVersionHeader, MetadataServiceVersion),
httpclient.WithRequestHeader(MetadataServiceHostHeader, s.config.Host),
httpclient.WithResponseUnmarshal(&inner),
)
if err != nil {
return nil, fmt.Errorf("token request: %w", err)
}
req.Header.Add(MetadataServiceVersionHeader, MetadataServiceVersion)
req.Header.Add(MetadataServiceHostHeader, s.config.Host)
return makeMsiRequest(req)
return inner.Token()
}

func (t metadataService) Token() (*oauth2.Token, error) {
Expand Down
Loading

0 comments on commit 3e96fd4

Please sign in to comment.