Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion cmd/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/credentials/u2m/cache"
"github.com/manifoldco/promptui"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
)
Expand All @@ -22,7 +26,7 @@ func helpfulError(ctx context.Context, profile string, persistentAuth u2m.OAuthA

func newTokenCommand(authArguments *auth.AuthArguments) *cobra.Command {
cmd := &cobra.Command{
Use: "token [HOST]",
Use: "token [HOST_OR_PROFILE]",
Short: "Get authentication token",
Long: `Get authentication token from the local cache in ~/.databricks/token-cache.json.
Refresh the access token if it is expired. Note: This command only works with
Expand Down Expand Up @@ -93,6 +97,19 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) {
return nil, errors.New("providing both a profile and host is not supported")
}

// If no --profile flag, try resolving the positional arg as a profile name.
// If it matches, use it. If not, fall through to host treatment.
if args.profileName == "" && len(args.args) == 1 {
candidateProfile, err := loadProfileByName(ctx, args.args[0], args.profiler)
if err != nil {
return nil, err
}
if candidateProfile != nil {
args.profileName = args.args[0]
args.args = nil
}
}

existingProfile, err := loadProfileByName(ctx, args.profileName, args.profiler)
if err != nil {
return nil, err
Expand All @@ -113,6 +130,47 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) {
return nil, err
}

// When no profile was specified, check if multiple profiles match the
// effective cache key for this host.
if args.profileName == "" && args.authArguments.Host != "" {
cfg := &config.Config{
Host: args.authArguments.Host,
AccountID: args.authArguments.AccountID,
Experimental_IsUnifiedHost: args.authArguments.IsUnifiedHost,
}
// Canonicalize first so HostType() can correctly identify account hosts
// even when the host string lacks a scheme (e.g. "accounts.cloud.databricks.com").
cfg.CanonicalHostName()
var matchFn profile.ProfileMatchFunction
switch cfg.HostType() {
case config.AccountHost, config.UnifiedHost:
matchFn = profile.WithHostAndAccountID(args.authArguments.Host, args.authArguments.AccountID)
default:
matchFn = profile.WithHost(args.authArguments.Host)
}

matchingProfiles, err := args.profiler.LoadProfiles(ctx, matchFn)
if err != nil && !errors.Is(err, profile.ErrNoConfiguration) {
return nil, err
}
if len(matchingProfiles) > 1 {
configPath, _ := args.profiler.GetPath(ctx)
if configPath == "" {
panic("configPath is empty but LoadProfiles returned multiple profiles")
}
if !cmdio.IsPromptSupported(ctx) {
names := strings.Join(matchingProfiles.Names(), " and ")
return nil, fmt.Errorf("%s match %s in %s. Use --profile to specify which profile to use",
names, args.authArguments.Host, configPath)
}
selected, err := askForMatchingProfile(ctx, matchingProfiles, args.authArguments.Host)
if err != nil {
return nil, err
}
args.profileName = selected
}
}

args.authArguments.Profile = args.profileName

ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout)
Expand Down Expand Up @@ -149,3 +207,22 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) {
}
return t, nil
}

func askForMatchingProfile(ctx context.Context, profiles profile.Profiles, host string) (string, error) {
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: "Multiple profiles match " + host,
Items: profiles,
Searcher: profiles.SearchCaseInsensitive,
StartInSearchMode: true,
Templates: &promptui.SelectTemplates{
Label: "{{ . | faint }}",
Active: `{{.Name | bold}} ({{.Host|faint}})`,
Inactive: `{{.Name}}`,
Selected: `{{ "Using profile" | faint }}: {{ .Name | bold }}`,
},
})
if err != nil {
return "", err
}
return profiles[i].Name, nil
}
201 changes: 200 additions & 1 deletion cmd/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
Expand Down Expand Up @@ -89,6 +90,32 @@ func TestToken_loadToken(t *testing.T) {
Host: "https://accounts.cloud.databricks.com",
AccountID: "active",
},
{
Name: "workspace-a",
Host: "https://workspace-a.cloud.databricks.com",
},
{
Name: "dup1",
Host: "https://shared.cloud.databricks.com",
},
{
Name: "dup2",
Host: "https://shared.cloud.databricks.com",
},
{
Name: "acct-dup1",
Host: "https://accounts.cloud.databricks.com",
AccountID: "same-account",
},
{
Name: "acct-dup2",
Host: "https://accounts.cloud.databricks.com",
AccountID: "same-account",
},
{
Name: "default.dev",
Host: "https://dev.cloud.databricks.com",
},
},
}
tokenCache := &inMemoryTokenCache{
Expand All @@ -107,6 +134,18 @@ func TestToken_loadToken(t *testing.T) {
RefreshToken: "active",
Expiry: time.Now().Add(1 * time.Hour),
},
"workspace-a": {
RefreshToken: "workspace-a",
Expiry: time.Now().Add(1 * time.Hour),
},
"https://workspace-a.cloud.databricks.com": {
RefreshToken: "workspace-a",
Expiry: time.Now().Add(1 * time.Hour),
},
"default.dev": {
RefreshToken: "default.dev",
Expiry: time.Now().Add(1 * time.Hour),
},
},
}
validateToken := func(resp *oauth2.Token) {
Expand All @@ -116,6 +155,7 @@ func TestToken_loadToken(t *testing.T) {

cases := []struct {
name string
ctx context.Context
args loadTokenArgs
validateToken func(*oauth2.Token)
wantErr string
Expand Down Expand Up @@ -223,10 +263,169 @@ func TestToken_loadToken(t *testing.T) {
},
validateToken: validateToken,
},
{
name: "positional arg resolved as profile name",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "",
args: []string{"workspace-a"},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
},
},
validateToken: validateToken,
},
{
name: "positional arg with dot treated as host when no profile matches",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "",
args: []string{"workspace-a.cloud.databricks.com"},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
},
},
validateToken: validateToken,
},
{
name: "dotted profile name resolved as profile not host",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "",
args: []string{"default.dev"},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
},
},
validateToken: validateToken,
},
{
name: "positional arg not a profile falls through to host",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "",
args: []string{"nonexistent"},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
},
},
wantErr: "cache: databricks OAuth is not configured for this host. " +
"Try logging in again with `databricks auth login --host https://nonexistent` before retrying. " +
"If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new",
},
{
name: "scheme-less account host ambiguity detected correctly",
ctx: cmdio.MockDiscard(context.Background()),
args: loadTokenArgs{
authArguments: &auth.AuthArguments{
Host: "accounts.cloud.databricks.com",
AccountID: "same-account",
},
profileName: "",
args: []string{},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
},
},
wantErr: "acct-dup1 and acct-dup2 match accounts.cloud.databricks.com in <in memory>. Use --profile to specify which profile to use",
},
{
name: "workspace host ambiguity — multiple profiles, non-interactive",
ctx: cmdio.MockDiscard(context.Background()),
args: loadTokenArgs{
authArguments: &auth.AuthArguments{
Host: "https://shared.cloud.databricks.com",
},
profileName: "",
args: []string{},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
},
},
wantErr: "dup1 and dup2 match https://shared.cloud.databricks.com in <in memory>. Use --profile to specify which profile to use",
},
{
name: "account host — same host, different account IDs — no ambiguity",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{
Host: "https://accounts.cloud.databricks.com",
AccountID: "active",
},
profileName: "",
args: []string{},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
},
},
validateToken: validateToken,
},
{
name: "account host — same host AND same account ID — ambiguity",
ctx: cmdio.MockDiscard(context.Background()),
args: loadTokenArgs{
authArguments: &auth.AuthArguments{
Host: "https://accounts.cloud.databricks.com",
AccountID: "same-account",
},
profileName: "",
args: []string{},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
},
},
wantErr: "acct-dup1 and acct-dup2 match https://accounts.cloud.databricks.com in <in memory>. Use --profile to specify which profile to use",
},
{
name: "profile flag + positional non-host arg still errors",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "active",
args: []string{"workspace-a"},
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
},
},
wantErr: "providing both a profile and host is not supported",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := loadToken(context.Background(), c.args)
ctx := c.ctx
if ctx == nil {
ctx = context.Background()
}
got, err := loadToken(ctx, c.args)
if c.wantErr != "" {
assert.Equal(t, c.wantErr, err.Error())
} else {
Expand Down
25 changes: 25 additions & 0 deletions libs/databrickscfg/profile/profiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package profile

import (
"context"

"github.com/databricks/databricks-sdk-go/config"
)

type ProfileMatchFunction func(Profile) bool
Expand Down Expand Up @@ -30,6 +32,29 @@ func WithName(name string) ProfileMatchFunction {
}
}

// WithHost returns a ProfileMatchFunction that matches profiles whose
// canonical host equals the given host.
func WithHost(host string) ProfileMatchFunction {
target := canonicalizeHost(host)
return func(p Profile) bool {
return p.Host != "" && canonicalizeHost(p.Host) == target
}
}

// WithHostAndAccountID returns a ProfileMatchFunction that matches profiles
// by both canonical host and account ID.
func WithHostAndAccountID(host, accountID string) ProfileMatchFunction {
target := canonicalizeHost(host)
return func(p Profile) bool {
return p.Host != "" && canonicalizeHost(p.Host) == target && p.AccountID == accountID
}
}

// canonicalizeHost normalizes a host using the SDK's canonical host logic.
func canonicalizeHost(host string) string {
return (&config.Config{Host: host}).CanonicalHostName()
}

type Profiler interface {
LoadProfiles(context.Context, ProfileMatchFunction) (Profiles, error)
GetPath(context.Context) (string, error)
Expand Down
Loading