Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions cmd/psql/psql.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/
ctx := cmd.Context()
w := cmdctx.WorkspaceClient(ctx)

instances, projects := listAllDatabases(ctx, w)
instances, projects, err := listAllDatabases(ctx, w)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

var names []string
for _, inst := range instances {
Expand Down Expand Up @@ -237,8 +240,10 @@ func parseResourcePath(input string) (project, branch, endpoint string, err erro
}

// listAllDatabases fetches all database instances and projects in parallel.
// Errors are silently ignored; callers should check for empty results.
func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, []postgres.Project) {
// A single failing call is tolerated because a workspace may have only one of
// the two products enabled; an error is returned only when both calls fail,
// so that e.g. an auth failure is not reported as an empty workspace.
func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, []postgres.Project, error) {
type result[T any] struct {
value []T
err error
Expand All @@ -260,6 +265,10 @@ func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]dat
instResult := <-instancesCh
projResult := <-projectsCh

if instResult.err != nil && projResult.err != nil {
return nil, nil, errors.Join(instResult.err, projResult.err)
}

var instances []database.DatabaseInstance
var projects []postgres.Project
if instResult.err == nil {
Expand All @@ -269,7 +278,7 @@ func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]dat
projects = projResult.value
}

return instances, projects
return instances, projects, nil
}

// showSelectionAndConnect shows a combined dropdown of Lakebase databases.
Expand All @@ -278,8 +287,11 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi

sp := cmdio.NewSpinner(ctx)
sp.Update("Loading Lakebase databases...")
instances, projects := listAllDatabases(ctx, w)
instances, projects, err := listAllDatabases(ctx, w)
sp.Close()
if err != nil {
return fmt.Errorf("failed to list Lakebase databases: %w", err)
}

if len(instances) == 0 && len(projects) == 0 {
return errors.New("no Lakebase databases found in workspace")
Expand Down
76 changes: 76 additions & 0 deletions cmd/psql/psql_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package psql

import (
"errors"
"testing"

"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/database"
"github.com/databricks/databricks-sdk-go/service/postgres"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -81,3 +86,74 @@ func TestParseResourcePath(t *testing.T) {
})
}
}

func TestListAllDatabases(t *testing.T) {
instErr := errors.New("instances list failed")
projErr := errors.New("projects list failed")
instances := []database.DatabaseInstance{{Name: "my-instance"}}
projects := []postgres.Project{{Name: "projects/my-project"}}

tests := []struct {
name string
instErr error
projErr error
wantInstances []database.DatabaseInstance
wantProjects []postgres.Project
wantErr bool
}{
{
name: "both succeed",
wantInstances: instances,
wantProjects: projects,
},
{
name: "instances call fails",
instErr: instErr,
wantProjects: projects,
},
{
name: "projects call fails",
projErr: projErr,
wantInstances: instances,
},
{
name: "both calls fail",
instErr: instErr,
projErr: projErr,
wantErr: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)

var instReturn []database.DatabaseInstance
if tc.instErr == nil {
instReturn = instances
}
m.GetMockDatabaseAPI().EXPECT().
ListDatabaseInstancesAll(mock.Anything, database.ListDatabaseInstancesRequest{}).
Return(instReturn, tc.instErr)

var projReturn []postgres.Project
if tc.projErr == nil {
projReturn = projects
}
m.GetMockPostgresAPI().EXPECT().
ListProjectsAll(mock.Anything, postgres.ListProjectsRequest{}).
Return(projReturn, tc.projErr)

gotInstances, gotProjects, err := listAllDatabases(t.Context(), m.WorkspaceClient)
if tc.wantErr {
require.Error(t, err)
assert.ErrorIs(t, err, tc.instErr)
assert.ErrorIs(t, err, tc.projErr)
return
}
require.NoError(t, err)
assert.Equal(t, tc.wantInstances, gotInstances)
assert.Equal(t, tc.wantProjects, gotProjects)
})
}
}
Loading