From 1d0a4627a9c8f57609af971b1efad043727e0f54 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Thu, 30 Apr 2026 18:16:50 +1200 Subject: [PATCH 1/3] refactor(db): use new store interface --- internal/assets/assets.go | 2 +- .../{ => sqlite}/000001_init_sqlite.down.sql | 0 .../{ => sqlite}/000001_init_sqlite.up.sql | 0 .../{ => sqlite}/000002_oauth_name.down.sql | 0 .../{ => sqlite}/000002_oauth_name.up.sql | 0 .../{ => sqlite}/000003_oauth_sub.down.sql | 0 .../{ => sqlite}/000003_oauth_sub.up.sql | 0 .../{ => sqlite}/000004_created_at.down.sql | 0 .../{ => sqlite}/000004_created_at.up.sql | 0 .../{ => sqlite}/000005_oidc_session.down.sql | 0 .../{ => sqlite}/000005_oidc_session.up.sql | 0 .../{ => sqlite}/000006_oidc_nonce.down.sql | 0 .../{ => sqlite}/000006_oidc_nonce.up.sql | 0 .../{ => sqlite}/000007_oidc_pkce.down.sql | 0 .../{ => sqlite}/000007_oidc_pkce.up.sql | 0 .../000008_oidc_code_reuse.down.sql | 0 .../000008_oidc_code_reuse.up.sql | 0 .../000009_oidc_userinfo_profile.down.sql | 0 .../000009_oidc_userinfo_profile.up.sql | 0 internal/bootstrap/app_bootstrap.go | 11 +-- internal/bootstrap/db_bootstrap.go | 19 ++++- internal/bootstrap/service_bootstrap.go | 2 +- internal/config/config.go | 2 +- internal/controller/oidc_controller_test.go | 13 +--- internal/controller/proxy_controller_test.go | 14 +--- internal/controller/user_controller_test.go | 14 +--- .../controller/well_known_controller_test.go | 14 +--- internal/repository/models.go | 73 ++++--------------- internal/repository/{ => sqlite}/db.go | 4 +- internal/repository/sqlite/models.go | 64 ++++++++++++++++ .../{ => sqlite}/oidc_queries.sql.go | 4 +- .../{ => sqlite}/session_queries.sql.go | 4 +- internal/repository/store.go | 41 +++++++++++ internal/service/auth_service.go | 4 +- internal/service/oidc_service.go | 4 +- sql/{ => sqlite}/oidc_queries.sql | 0 sql/{ => sqlite}/oidc_schemas.sql | 0 sql/{ => sqlite}/session_queries.sql | 0 sql/{ => sqlite}/session_schemas.sql | 0 sqlc.yml | 8 +- 40 files changed, 164 insertions(+), 133 deletions(-) rename internal/assets/migrations/{ => sqlite}/000001_init_sqlite.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000001_init_sqlite.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000002_oauth_name.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000002_oauth_name.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000003_oauth_sub.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000003_oauth_sub.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000004_created_at.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000004_created_at.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000005_oidc_session.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000005_oidc_session.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000006_oidc_nonce.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000006_oidc_nonce.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000007_oidc_pkce.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000007_oidc_pkce.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000008_oidc_code_reuse.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000008_oidc_code_reuse.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000009_oidc_userinfo_profile.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000009_oidc_userinfo_profile.up.sql (100%) rename internal/repository/{ => sqlite}/db.go (93%) create mode 100644 internal/repository/sqlite/models.go rename internal/repository/{ => sqlite}/oidc_queries.sql.go (99%) rename internal/repository/{ => sqlite}/session_queries.sql.go (98%) create mode 100644 internal/repository/store.go rename sql/{ => sqlite}/oidc_queries.sql (100%) rename sql/{ => sqlite}/oidc_schemas.sql (100%) rename sql/{ => sqlite}/session_queries.sql (100%) rename sql/{ => sqlite}/session_schemas.sql (100%) diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 412403c9..a5c3d79d 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -11,5 +11,5 @@ var FrontendAssets embed.FS // Migrations // -//go:embed migrations/*.sql +//go:embed migrations/sqlite/*.sql var Migrations embed.FS diff --git a/internal/assets/migrations/000001_init_sqlite.down.sql b/internal/assets/migrations/sqlite/000001_init_sqlite.down.sql similarity index 100% rename from internal/assets/migrations/000001_init_sqlite.down.sql rename to internal/assets/migrations/sqlite/000001_init_sqlite.down.sql diff --git a/internal/assets/migrations/000001_init_sqlite.up.sql b/internal/assets/migrations/sqlite/000001_init_sqlite.up.sql similarity index 100% rename from internal/assets/migrations/000001_init_sqlite.up.sql rename to internal/assets/migrations/sqlite/000001_init_sqlite.up.sql diff --git a/internal/assets/migrations/000002_oauth_name.down.sql b/internal/assets/migrations/sqlite/000002_oauth_name.down.sql similarity index 100% rename from internal/assets/migrations/000002_oauth_name.down.sql rename to internal/assets/migrations/sqlite/000002_oauth_name.down.sql diff --git a/internal/assets/migrations/000002_oauth_name.up.sql b/internal/assets/migrations/sqlite/000002_oauth_name.up.sql similarity index 100% rename from internal/assets/migrations/000002_oauth_name.up.sql rename to internal/assets/migrations/sqlite/000002_oauth_name.up.sql diff --git a/internal/assets/migrations/000003_oauth_sub.down.sql b/internal/assets/migrations/sqlite/000003_oauth_sub.down.sql similarity index 100% rename from internal/assets/migrations/000003_oauth_sub.down.sql rename to internal/assets/migrations/sqlite/000003_oauth_sub.down.sql diff --git a/internal/assets/migrations/000003_oauth_sub.up.sql b/internal/assets/migrations/sqlite/000003_oauth_sub.up.sql similarity index 100% rename from internal/assets/migrations/000003_oauth_sub.up.sql rename to internal/assets/migrations/sqlite/000003_oauth_sub.up.sql diff --git a/internal/assets/migrations/000004_created_at.down.sql b/internal/assets/migrations/sqlite/000004_created_at.down.sql similarity index 100% rename from internal/assets/migrations/000004_created_at.down.sql rename to internal/assets/migrations/sqlite/000004_created_at.down.sql diff --git a/internal/assets/migrations/000004_created_at.up.sql b/internal/assets/migrations/sqlite/000004_created_at.up.sql similarity index 100% rename from internal/assets/migrations/000004_created_at.up.sql rename to internal/assets/migrations/sqlite/000004_created_at.up.sql diff --git a/internal/assets/migrations/000005_oidc_session.down.sql b/internal/assets/migrations/sqlite/000005_oidc_session.down.sql similarity index 100% rename from internal/assets/migrations/000005_oidc_session.down.sql rename to internal/assets/migrations/sqlite/000005_oidc_session.down.sql diff --git a/internal/assets/migrations/000005_oidc_session.up.sql b/internal/assets/migrations/sqlite/000005_oidc_session.up.sql similarity index 100% rename from internal/assets/migrations/000005_oidc_session.up.sql rename to internal/assets/migrations/sqlite/000005_oidc_session.up.sql diff --git a/internal/assets/migrations/000006_oidc_nonce.down.sql b/internal/assets/migrations/sqlite/000006_oidc_nonce.down.sql similarity index 100% rename from internal/assets/migrations/000006_oidc_nonce.down.sql rename to internal/assets/migrations/sqlite/000006_oidc_nonce.down.sql diff --git a/internal/assets/migrations/000006_oidc_nonce.up.sql b/internal/assets/migrations/sqlite/000006_oidc_nonce.up.sql similarity index 100% rename from internal/assets/migrations/000006_oidc_nonce.up.sql rename to internal/assets/migrations/sqlite/000006_oidc_nonce.up.sql diff --git a/internal/assets/migrations/000007_oidc_pkce.down.sql b/internal/assets/migrations/sqlite/000007_oidc_pkce.down.sql similarity index 100% rename from internal/assets/migrations/000007_oidc_pkce.down.sql rename to internal/assets/migrations/sqlite/000007_oidc_pkce.down.sql diff --git a/internal/assets/migrations/000007_oidc_pkce.up.sql b/internal/assets/migrations/sqlite/000007_oidc_pkce.up.sql similarity index 100% rename from internal/assets/migrations/000007_oidc_pkce.up.sql rename to internal/assets/migrations/sqlite/000007_oidc_pkce.up.sql diff --git a/internal/assets/migrations/000008_oidc_code_reuse.down.sql b/internal/assets/migrations/sqlite/000008_oidc_code_reuse.down.sql similarity index 100% rename from internal/assets/migrations/000008_oidc_code_reuse.down.sql rename to internal/assets/migrations/sqlite/000008_oidc_code_reuse.down.sql diff --git a/internal/assets/migrations/000008_oidc_code_reuse.up.sql b/internal/assets/migrations/sqlite/000008_oidc_code_reuse.up.sql similarity index 100% rename from internal/assets/migrations/000008_oidc_code_reuse.up.sql rename to internal/assets/migrations/sqlite/000008_oidc_code_reuse.up.sql diff --git a/internal/assets/migrations/000009_oidc_userinfo_profile.down.sql b/internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.down.sql similarity index 100% rename from internal/assets/migrations/000009_oidc_userinfo_profile.down.sql rename to internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.down.sql diff --git a/internal/assets/migrations/000009_oidc_userinfo_profile.up.sql b/internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.up.sql similarity index 100% rename from internal/assets/migrations/000009_oidc_userinfo_profile.up.sql rename to internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.up.sql diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 3879c05e..8e204f0e 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -130,17 +130,14 @@ func (app *BootstrapApp) Setup() error { tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") // Database - db, err := app.SetupDatabase(app.config.Database.Path) + store, err := app.SetupStore() if err != nil { return fmt.Errorf("failed to setup database: %w", err) } - // Queries - queries := repository.New(db) - // Services - services, err := app.initServices(queries) + services, err := app.initServices(store) if err != nil { return fmt.Errorf("failed to initialize services: %w", err) @@ -196,7 +193,7 @@ func (app *BootstrapApp) Setup() error { // Start db cleanup routine tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine(queries) + go app.dbCleanupRoutine(store) // If analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { @@ -286,7 +283,7 @@ func (app *BootstrapApp) heartbeatRoutine() { } } -func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { +func (app *BootstrapApp) dbCleanupRoutine(queries repository.Store) { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() ctx := context.Background() diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 3f48f793..efc21311 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -7,6 +7,8 @@ import ( "path/filepath" "github.com/tinyauthapp/tinyauth/internal/assets" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" @@ -14,7 +16,18 @@ import ( _ "modernc.org/sqlite" ) -func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { +func (app *BootstrapApp) SetupStore() (repository.Store, error) { + return app.setupSQLite(app.config.Database.Path) +} + +// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store. +// Useful for testing or when constructing a store outside of a BootstrapApp. +func NewSQLiteStore(databasePath string) (repository.Store, error) { + app := &BootstrapApp{} + return app.setupSQLite(databasePath) +} + +func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { dir := filepath.Dir(databasePath) if err := os.MkdirAll(dir, 0750); err != nil { @@ -31,7 +44,7 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { // if the sqlite connection starts being a bottleneck db.SetMaxOpenConns(1) - migrations, err := iofs.New(assets.Migrations, "migrations") + migrations, err := iofs.New(assets.Migrations, "migrations/sqlite") if err != nil { return nil, fmt.Errorf("failed to create migrations: %w", err) @@ -53,5 +66,5 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { return nil, fmt.Errorf("failed to migrate database: %w", err) } - return db, nil + return sqlite.New(db), nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 91e2b50b..7cdeaf4d 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -18,7 +18,7 @@ type Services struct { oidcService *service.OIDCService } -func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { +func (app *BootstrapApp) initServices(queries repository.Store) (Services, error) { services := Services{} ldapService := service.NewLdapService(service.LdapServiceConfig{ diff --git a/internal/config/config.go b/internal/config/config.go index e364b458..5b14e27e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -95,7 +95,7 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the database, including file name." yaml:"path"` + Path string `description:"The path to the SQLite database, including file name." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index a09697bf..991f6759 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -15,7 +15,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -848,13 +847,10 @@ func TestOIDCController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(config.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) - queries := repository.New(db) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) + oidcService := service.NewOIDCService(oidcServiceCfg, store) err = oidcService.Init() require.NoError(t, err) @@ -877,9 +873,4 @@ func TestOIDCController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 8efbd31c..adfc7fb1 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -9,7 +9,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -393,13 +392,9 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(config.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) - queries := repository.New(db) - docker := service.NewDockerService() err = docker.Init() require.NoError(t, err) @@ -412,7 +407,7 @@ func TestProxyController(t *testing.T) { err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -437,9 +432,4 @@ func TestProxyController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 65ef15ef..b67c70fa 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -13,7 +13,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -351,13 +350,9 @@ func TestUserController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(config.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) - queries := repository.New(db) - docker := service.NewDockerService() err = docker.Init() require.NoError(t, err) @@ -370,7 +365,7 @@ func TestUserController(t *testing.T) { err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -435,9 +430,4 @@ func TestUserController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 7d8d05f5..eba449b0 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -11,7 +11,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -101,14 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(config.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) - queries := repository.New(db) - - oidcService := service.NewOIDCService(oidcServiceCfg, queries) + oidcService := service.NewOIDCService(oidcServiceCfg, store) err = oidcService.Init() require.NoError(t, err) @@ -125,9 +120,4 @@ func TestWellKnownController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/repository/models.go b/internal/repository/models.go index bc2e2c66..0c33e038 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -1,64 +1,19 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 - package repository -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} +// This file is a stop-gap until more drivers are added. It re-exports the models from the sqlite package so that the rest +// of the codebase can import them from a single location without needing to know about the underlying database implementation. -type OidcToken struct { - Sub string - AccessTokenHash string - RefreshTokenHash string - CodeHash string - Scope string - ClientID string - TokenExpiresAt int64 - RefreshTokenExpiresAt int64 - Nonce string -} +import "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} +type Session = sqlite.Session +type OidcCode = sqlite.OidcCode +type OidcToken = sqlite.OidcToken +type OidcUserinfo = sqlite.OidcUserinfo -type Session struct { - UUID string - Username string - Email string - Name string - Provider string - TotpPending bool - OAuthGroups string - Expiry int64 - CreatedAt int64 - OAuthName string - OAuthSub string -} +type CreateSessionParams = sqlite.CreateSessionParams +type UpdateSessionParams = sqlite.UpdateSessionParams +type CreateOidcCodeParams = sqlite.CreateOidcCodeParams +type CreateOidcTokenParams = sqlite.CreateOidcTokenParams +type UpdateOidcTokenByRefreshTokenParams = sqlite.UpdateOidcTokenByRefreshTokenParams +type DeleteExpiredOidcTokensParams = sqlite.DeleteExpiredOidcTokensParams +type CreateOidcUserInfoParams = sqlite.CreateOidcUserInfoParams diff --git a/internal/repository/db.go b/internal/repository/sqlite/db.go similarity index 93% rename from internal/repository/db.go rename to internal/repository/sqlite/db.go index 998bfd3b..ee310fc2 100644 --- a/internal/repository/db.go +++ b/internal/repository/sqlite/db.go @@ -1,8 +1,8 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.0 -package repository +package sqlite import ( "context" diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go new file mode 100644 index 00000000..caf37f4c --- /dev/null +++ b/internal/repository/sqlite/models.go @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 + +package sqlite + +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + CodeHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + Nonce string +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go similarity index 99% rename from internal/repository/oidc_queries.sql.go rename to internal/repository/sqlite/oidc_queries.sql.go index 7caac9d4..027ac421 100644 --- a/internal/repository/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,9 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.0 // source: oidc_queries.sql -package repository +package sqlite import ( "context" diff --git a/internal/repository/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go similarity index 98% rename from internal/repository/session_queries.sql.go rename to internal/repository/sqlite/session_queries.sql.go index c846c3f9..4271b727 100644 --- a/internal/repository/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,9 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.0 // source: session_queries.sql -package repository +package sqlite import ( "context" diff --git a/internal/repository/store.go b/internal/repository/store.go new file mode 100644 index 00000000..765df6a5 --- /dev/null +++ b/internal/repository/store.go @@ -0,0 +1,41 @@ +package repository + +import "context" + +// Store is the interface that all storage drivers must implement. +// The sqlc-generated *Queries struct satisfies this interface for SQLite. +// Future drivers (postgres, etc.) must return the shared types defined in this package. +type Store interface { + // Sessions + CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) + GetSession(ctx context.Context, uuid string) (Session, error) + UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) + DeleteSession(ctx context.Context, uuid string) error + DeleteExpiredSessions(ctx context.Context, expiry int64) error + + // OIDC codes + CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) + GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) + GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) + GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) + GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) + DeleteOidcCode(ctx context.Context, codeHash string) error + DeleteOidcCodeBySub(ctx context.Context, sub string) error + DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) + + // OIDC tokens + CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) + GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) + GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) + GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) + UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) + DeleteOidcToken(ctx context.Context, accessTokenHash string) error + DeleteOidcTokenBySub(ctx context.Context, sub string) error + DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error + DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) + + // OIDC userinfo + CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) + GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) + DeleteOidcUserInfo(ctx context.Context, sub string) error +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 0311229d..ab343396 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -90,14 +90,14 @@ type AuthService struct { loginMutex sync.RWMutex ldapGroupsMutex sync.RWMutex ldap *LdapService - queries *repository.Queries + queries repository.Store oauthBroker *OAuthBrokerService lockdown *Lockdown lockdownCtx context.Context lockdownCancelFunc context.CancelFunc } -func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { +func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries repository.Store, oauthBroker *OAuthBrokerService) *AuthService { return &AuthService{ config: config, loginAttempts: make(map[string]*LoginAttempt), diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1ac138ae..e5f7ea76 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -121,7 +121,7 @@ type OIDCServiceConfig struct { type OIDCService struct { config OIDCServiceConfig - queries *repository.Queries + queries repository.Store clients map[string]config.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey @@ -129,7 +129,7 @@ type OIDCService struct { isConfigured bool } -func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { +func NewOIDCService(config OIDCServiceConfig, queries repository.Store) *OIDCService { return &OIDCService{ config: config, queries: queries, diff --git a/sql/oidc_queries.sql b/sql/sqlite/oidc_queries.sql similarity index 100% rename from sql/oidc_queries.sql rename to sql/sqlite/oidc_queries.sql diff --git a/sql/oidc_schemas.sql b/sql/sqlite/oidc_schemas.sql similarity index 100% rename from sql/oidc_schemas.sql rename to sql/sqlite/oidc_schemas.sql diff --git a/sql/session_queries.sql b/sql/sqlite/session_queries.sql similarity index 100% rename from sql/session_queries.sql rename to sql/sqlite/session_queries.sql diff --git a/sql/session_schemas.sql b/sql/sqlite/session_schemas.sql similarity index 100% rename from sql/session_schemas.sql rename to sql/sqlite/session_schemas.sql diff --git a/sqlc.yml b/sqlc.yml index de08738a..e7b2c4b4 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -1,12 +1,12 @@ version: "2" sql: - engine: "sqlite" - queries: "sql/*_queries.sql" - schema: "sql/*_schemas.sql" + queries: "sql/sqlite/*_queries.sql" + schema: "sql/sqlite/*_schemas.sql" gen: go: - package: "repository" - out: "internal/repository" + package: "sqlite" + out: "internal/repository/sqlite" rename: uuid: "UUID" oauth_groups: "OAuthGroups" From 0244f39387be47a925af2eec15eba5b8a7820b30 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Sun, 3 May 2026 13:49:24 +1200 Subject: [PATCH 2/3] feat(db): add code gen to build sqlc-compatible wrappers --- .github/workflows/ci.yml | 6 + cmd/gen/sqlc-wrapper/main.go | 522 +++++++++++++++++++++++++ go.mod | 2 + internal/bootstrap/db_bootstrap.go | 2 +- internal/repository/models.go | 163 +++++++- internal/repository/sqlite/generate.go | 3 + internal/repository/sqlite/store.go | 206 ++++++++++ 7 files changed, 886 insertions(+), 18 deletions(-) create mode 100644 cmd/gen/sqlc-wrapper/main.go create mode 100644 internal/repository/sqlite/generate.go create mode 100644 internal/repository/sqlite/store.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12db1641..fb8c9736 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,12 @@ jobs: - name: Go dependencies run: go mod download + - name: Check codegen is up to date + run: | + go generate ./internal/repository/... + git diff --exit-code -- internal/repository/ + git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true + - name: Install frontend dependencies run: | cd frontend diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go new file mode 100644 index 00000000..e66ae8ee --- /dev/null +++ b/cmd/gen/sqlc-wrapper/main.go @@ -0,0 +1,522 @@ +// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under +// internal/repository//. Run via: +// +// go generate ./internal/repository/... +// +// The generator introspects *Queries methods and the model/params types in the +// driver package, then emits a store.go that wraps *Queries so it satisfies +// repository.Store using the canonical shared types in the parent package. +// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should +// implement repository.Store directly by hand. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "go/types" + "log" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "text/template" + + "golang.org/x/tools/go/packages" +) + +func main() { + driverPkg := flag.String("pkg", "", "import path of the driver package") + out := flag.String("out", "store.go", "output filename relative to driver package directory") + flag.Parse() + + if *driverPkg == "" { + log.Fatal("-pkg is required") + } + + // Resolve the driver package directory so we can overlay the output file + // with a valid stub. This prevents a stale store.go from poisoning the + // type-checker and producing cryptic "undefined" errors. + driverDir, err := pkgDir(*driverPkg) + if err != nil { + log.Fatalf("resolve driver dir: %v", err) + } + outPath := filepath.Join(driverDir, *out) + if filepath.IsAbs(*out) { + outPath = *out + } + + // Stub replaces the output file during load so stale generated code is ignored. + stub := []byte("package " + filepath.Base(driverDir) + "\n") + + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports, + Overlay: map[string][]byte{outPath: stub}, + } + pkgs, err := packages.Load(cfg, *driverPkg) + if err != nil { + log.Fatalf("load %s: %v", *driverPkg, err) + } + if len(pkgs) != 1 { + log.Fatalf("expected 1 package, got %d", len(pkgs)) + } + pkg := pkgs[0] + if len(pkg.Errors) > 0 { + for _, e := range pkg.Errors { + log.Printf("package error: %v", e) + } + log.Fatal("package has errors") + } + + repoPkg := parentPkg(*driverPkg) + + // Load the parent (repository) package so we can validate struct shapes. + repoPkgs, err := packages.Load(cfg, repoPkg) + if err != nil { + log.Fatalf("load repo pkg %s: %v", repoPkg, err) + } + if len(repoPkgs) != 1 || len(repoPkgs[0].Errors) > 0 { + log.Fatalf("could not load repo package %s cleanly", repoPkg) + } + if err := validateStructShapes(pkg.Types, repoPkgs[0].Types); err != nil { + log.Fatalf("struct shape mismatch: %v", err) + } + + // Check *Queries covers every method in repository.Store before generating. + if err := validateStoreCoverage(pkg.Types, repoPkgs[0].Types); err != nil { + log.Fatalf("%v", err) + } + + methods, err := collectMethods(pkg.Types) + if err != nil { + log.Fatal(err) + } + + models, _ := collectTypes(pkg.Types) + + data := tmplData{ + PkgName: pkg.Name, + RepoPkg: repoPkg, + ModelTypes: models, + Methods: renderMethods(methods), + } + + src, err := render(data) + if err != nil { + log.Fatalf("render: %v", err) + } + + if err := os.WriteFile(outPath, src, 0644); err != nil { + log.Fatalf("write %s: %v", outPath, err) + } + fmt.Printf("wrote %s\n", outPath) +} + +func parentPkg(imp string) string { + parts := strings.Split(imp, "/") + return strings.Join(parts[:len(parts)-1], "/") +} + +// pkgDir returns the on-disk directory for an import path using `go list`. +func pkgDir(importPath string) (string, error) { + out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output() + if err != nil { + return "", fmt.Errorf("go list %s: %w", importPath, err) + } + return strings.TrimSpace(string(out)), nil +} + +// validateStoreCoverage checks that every method declared in repository.Store +// exists on *Queries in the driver package. Missing methods are reported by +// name so the developer knows exactly which SQL queries need to be added. +func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { + // Collect *Queries method names. + queriesObj := driverPkg.Scope().Lookup("Queries") + if queriesObj == nil { + return fmt.Errorf("Queries type not found in driver package") + } + queriesNamed := queriesObj.Type().(*types.Named) + queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed)) + queriesMethods := make(map[string]bool) + for m := range queriesMS.Methods() { + queriesMethods[m.Obj().Name()] = true + } + + // Collect repository.Store interface methods. + storeObj := repoPkg.Scope().Lookup("Store") + if storeObj == nil { + return fmt.Errorf("Store type not found in repository package") + } + storeIface, ok := storeObj.Type().Underlying().(*types.Interface) + if !ok { + return fmt.Errorf("repository.Store is not an interface") + } + + var missing []string + for i := range storeIface.NumMethods() { + name := storeIface.Method(i).Name() + if !queriesMethods[name] { + missing = append(missing, name) + } + } + if len(missing) > 0 { + sort.Strings(missing) + return fmt.Errorf( + "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries.", + len(missing), strings.Join(missing, "\n - "), + ) + } + return nil +} + +type methodInfo struct { + Name string + Params []paramInfo + Results []resultInfo +} + +type paramInfo struct { + Name string + TypeStr string // local (unqualified) type name + RepoType string // "repository.X" if this is a driver model/params type; else "" +} + +type resultInfo struct { + TypeStr string + IsSlice bool + RepoType string // "repository.X" if driver type; else "" +} + +func collectMethods(pkg *types.Package) ([]methodInfo, error) { + obj := pkg.Scope().Lookup("Queries") + if obj == nil { + return nil, fmt.Errorf("queries type not found in %s", pkg.Path()) + } + named, ok := obj.Type().(*types.Named) + if !ok { + return nil, fmt.Errorf("queries is not a named type") + } + ms := types.NewMethodSet(types.NewPointer(named)) + + var out []methodInfo + for method := range ms.Methods() { + fn, ok := method.Obj().(*types.Func) + if !ok || fn.Name() == "WithTx" { + continue + } + sig := fn.Type().(*types.Signature) + mi := methodInfo{Name: fn.Name()} + + // params: skip receiver + first (context.Context) + for i := 1; i < sig.Params().Len(); i++ { + p := sig.Params().At(i) + mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path())) + } + // results: skip error + for r := range sig.Results().Variables() { + if r.Type().String() == "error" { + continue + } + mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path())) + } + out = append(out, mi) + } + return out, nil +} + +func makeParam(name string, t types.Type, driverPath string) paramInfo { + pi := paramInfo{Name: name} + pi.TypeStr = localName(t, driverPath) + pi.RepoType = repoName(t, driverPath) + return pi +} + +func makeResult(t types.Type, driverPath string) resultInfo { + ri := resultInfo{} + if sl, ok := t.(*types.Slice); ok { + ri.IsSlice = true + t = sl.Elem() + } + ri.TypeStr = localName(t, driverPath) + ri.RepoType = repoName(t, driverPath) + return ri +} + +func localName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return types.TypeString(t, nil) + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return named.Obj().Name() + } + return types.TypeString(t, func(p *types.Package) string { return p.Name() }) +} + +func repoName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return "" + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return "repository." + named.Obj().Name() + } + return "" +} + +func collectTypes(pkg *types.Package) (models []string, params []string) { + for _, name := range pkg.Scope().Names() { + obj := pkg.Scope().Lookup(name) + if obj == nil { + continue + } + tn, ok := obj.(*types.TypeName) + if !ok { + continue + } + named, ok := tn.Type().(*types.Named) + if !ok { + continue + } + if _, ok := named.Underlying().(*types.Struct); !ok { + continue + } + switch name { + case "Queries", "DBTX", "Store": + continue + } + if strings.HasSuffix(name, "Params") { + params = append(params, name) + } else { + models = append(models, name) + } + } + return +} + +// validateStructShapes checks that every model/params struct in the driver +// package has fields that exactly match the corresponding type in the repo +// (parent) package. This catches drift between sqlc-generated types and the +// canonical repository types before a broken cast reaches the compiler. +func validateStructShapes(driverPkg, repoPkg *types.Package) error { + var errs []string + for _, name := range driverPkg.Scope().Names() { + obj := driverPkg.Scope().Lookup(name) + if obj == nil { + continue + } + tn, ok := obj.(*types.TypeName) + if !ok { + continue + } + named, ok := tn.Type().(*types.Named) + if !ok { + continue + } + driverStruct, ok := named.Underlying().(*types.Struct) + if !ok { + continue + } + switch name { + case "Queries", "DBTX", "Store": + continue + } + + repoObj := repoPkg.Scope().Lookup(name) + if repoObj == nil { + // Driver has a type not in repo — that's fine (e.g. internal helpers). + continue + } + repoNamed, ok := repoObj.Type().(*types.Named) + if !ok { + continue + } + repoStruct, ok := repoNamed.Underlying().(*types.Struct) + if !ok { + errs = append(errs, fmt.Sprintf("%s: repo type is not a struct", name)) + continue + } + + if err := compareStructs(name, driverStruct, repoStruct); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return fmt.Errorf("%s", strings.Join(errs, "\n ")) + } + return nil +} + +func compareStructs(name string, driver, repo *types.Struct) error { + if driver.NumFields() != repo.NumFields() { + return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", + name, driver.NumFields(), repo.NumFields()) + } + for i := range driver.NumFields() { + df := driver.Field(i) + rf := repo.Field(i) + if df.Name() != rf.Name() { + return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", + name, i, df.Name(), rf.Name()) + } + if !types.Identical(df.Type(), rf.Type()) { + return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", + name, df.Name(), df.Type(), rf.Type()) + } + } + return nil +} + +// converterFn: "Session" -> "sessionToRepo" +func converterFn(s string) string { + if s == "" { + return "" + } + r := []rune(s) + r[0] = []rune(strings.ToLower(string(r[0])))[0] + return string(r) + "ToRepo" +} + +// renderedMethod is the pre-built method body passed to the template. +type renderedMethod struct { + Signature string + Body string +} + +// renderMethods converts []methodInfo into fully pre-rendered signature+body strings. +func renderMethods(methods []methodInfo) []renderedMethod { + var out []renderedMethod + for _, m := range methods { + out = append(out, renderedMethod{ + Signature: buildSig(m), + Body: buildBody(m), + }) + } + return out +} + +func buildSig(m methodInfo) string { + var sb strings.Builder + sb.WriteString("func (s *Store) ") + sb.WriteString(m.Name) + sb.WriteString("(ctx context.Context") + for _, p := range m.Params { + sb.WriteString(", ") + sb.WriteString(p.Name) + sb.WriteString(" ") + if p.RepoType != "" { + sb.WriteString(p.RepoType) + } else { + sb.WriteString(p.TypeStr) + } + } + sb.WriteString(") (") + for _, r := range m.Results { + if r.IsSlice { + sb.WriteString("[]") + } + if r.RepoType != "" { + sb.WriteString(r.RepoType) + } else { + sb.WriteString(r.TypeStr) + } + sb.WriteString(", ") + } + sb.WriteString("error)") + return sb.String() +} + +func callArgs(m methodInfo) string { + var args []string + for _, p := range m.Params { + if p.RepoType != "" { + // convert repo type → driver type: DriverType(arg) + args = append(args, p.TypeStr+"("+p.Name+")") + } else { + args = append(args, p.Name) + } + } + if len(args) == 0 { + return "ctx" + } + return "ctx, " + strings.Join(args, ", ") +} + +func buildBody(m methodInfo) string { + call := "s.q." + m.Name + "(" + callArgs(m) + ")" + + // no repo-typed result → direct return + if len(m.Results) == 0 || m.Results[0].RepoType == "" { + return "\treturn " + call + "\n" + } + + r := m.Results[0] + if r.IsSlice { + return fmt.Sprintf( + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + call, r.RepoType, converterFn(r.TypeStr), + ) + } + return fmt.Sprintf( + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + call, r.RepoType, converterFn(r.TypeStr), + ) +} + +type tmplData struct { + PkgName string + RepoPkg string + ModelTypes []string + Methods []renderedMethod +} + +const storeSrc = `// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package {{.PkgName}} + +import ( + "context" + + "{{.RepoPkg}}" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +{{range .ModelTypes -}} +func {{converterFn .}}(v {{.}}) repository.{{.}} { + return repository.{{.}}(v) +} +{{end -}} +{{range .Methods}}{{.Signature}} { +{{.Body}}} + +{{end}}` + +func render(data tmplData) ([]byte, error) { + t, err := template.New("store").Funcs(template.FuncMap{ + "converterFn": converterFn, + }).Parse(storeSrc) + if err != nil { + return nil, fmt.Errorf("parse template: %w", err) + } + + var buf bytes.Buffer + if err := t.Execute(&buf, data); err != nil { + return nil, fmt.Errorf("execute template: %w", err) + } + + formatted, err := format.Source(buf.Bytes()) + if err != nil { + return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String()) + } + return formatted, nil +} diff --git a/go.mod b/go.mod index d0c5a515..fb4a459c 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.50.0 golang.org/x/oauth2 v0.36.0 + golang.org/x/tools v0.43.0 gotest.tools/v3 v3.5.2 k8s.io/apimachinery v0.32.2 k8s.io/client-go v0.32.2 @@ -124,6 +125,7 @@ require ( go.opentelemetry.io/otel/trace v1.43.0 // indirect golang.org/x/arch v0.22.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/mod v0.34.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index efc21311..2279cb23 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -66,5 +66,5 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err return nil, fmt.Errorf("failed to migrate database: %w", err) } - return sqlite.New(db), nil + return sqlite.NewStore(sqlite.New(db)), nil } diff --git a/internal/repository/models.go b/internal/repository/models.go index 0c33e038..3f58dd66 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -1,19 +1,148 @@ package repository -// This file is a stop-gap until more drivers are added. It re-exports the models from the sqlite package so that the rest -// of the codebase can import them from a single location without needing to know about the underlying database implementation. - -import "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" - -type Session = sqlite.Session -type OidcCode = sqlite.OidcCode -type OidcToken = sqlite.OidcToken -type OidcUserinfo = sqlite.OidcUserinfo - -type CreateSessionParams = sqlite.CreateSessionParams -type UpdateSessionParams = sqlite.UpdateSessionParams -type CreateOidcCodeParams = sqlite.CreateOidcCodeParams -type CreateOidcTokenParams = sqlite.CreateOidcTokenParams -type UpdateOidcTokenByRefreshTokenParams = sqlite.UpdateOidcTokenByRefreshTokenParams -type DeleteExpiredOidcTokensParams = sqlite.DeleteExpiredOidcTokensParams -type CreateOidcUserInfoParams = sqlite.CreateOidcUserInfoParams +// Shared model and parameter types for all storage drivers. +// sqlc-generated driver packages use these via the conversion layer in their store.go. + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} + +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + CodeHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + Nonce string +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +type CreateSessionParams struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} + +type UpdateSessionParams struct { + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string + OAuthSub string + UUID string +} + +type CreateOidcCodeParams struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type CreateOidcTokenParams struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + CodeHash string + Nonce string +} + +type UpdateOidcTokenByRefreshTokenParams struct { + AccessTokenHash string + RefreshTokenHash string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + RefreshTokenHash_2 string +} + +type DeleteExpiredOidcTokensParams struct { + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +type CreateOidcUserInfoParams struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} diff --git a/internal/repository/sqlite/generate.go b/internal/repository/sqlite/generate.go new file mode 100644 index 00000000..5f6011f1 --- /dev/null +++ b/internal/repository/sqlite/generate.go @@ -0,0 +1,3 @@ +package sqlite + +//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go new file mode 100644 index 00000000..65b4e190 --- /dev/null +++ b/internal/repository/sqlite/store.go @@ -0,0 +1,206 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package sqlite + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +func oidcCodeToRepo(v OidcCode) repository.OidcCode { + return repository.OidcCode(v) +} +func oidcTokenToRepo(v OidcToken) repository.OidcToken { + return repository.OidcToken(v) +} +func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo { + return repository.OidcUserinfo(v) +} +func sessionToRepo(v Session) repository.Session { + return repository.Session(v) +} +func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) + if err != nil { + return repository.OidcUserinfo{}, err + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} + +func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { + rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) + if err != nil { + return nil, err + } + out := make([]repository.OidcCode, len(rows)) + for i, row := range rows { + out[i] = oidcCodeToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) + if err != nil { + return nil, err + } + out := make([]repository.OidcToken, len(rows)) + for i, row := range rows { + out[i] = oidcTokenToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { + return s.q.DeleteExpiredSessions(ctx, expiry) +} + +func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { + return s.q.DeleteOidcCode(ctx, codeHash) +} + +func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { + return s.q.DeleteOidcCodeBySub(ctx, sub) +} + +func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { + return s.q.DeleteOidcToken(ctx, accessTokenHash) +} + +func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { + return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) +} + +func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { + return s.q.DeleteOidcTokenBySub(ctx, sub) +} + +func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { + return s.q.DeleteOidcUserInfo(ctx, sub) +} + +func (s *Store) DeleteSession(ctx context.Context, uuid string) error { + return s.q.DeleteSession(ctx, uuid) +} + +func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCode(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySub(ctx, sub) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcToken(ctx, accessTokenHash) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenBySub(ctx, sub) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { + r, err := s.q.GetOidcUserInfo(ctx, sub) + if err != nil { + return repository.OidcUserinfo{}, err + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { + r, err := s.q.GetSession(ctx, uuid) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} + +func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} From 04b8e9884bcd76ae502fd686f6e8ab5c8e703e60 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Mon, 4 May 2026 05:02:27 +1200 Subject: [PATCH 3/3] feat(db): add `memory` storage driver removes the sqlite dependency for tests, also brings back the option for users to run zero persistence instances of tinyauth. adds new mapErr fn for sqlc wrapper gen to prevent sql errors from leaking out of the store implementation. --- cmd/gen/sqlc-wrapper/main.go | 24 +- internal/bootstrap/db_bootstrap.go | 17 +- internal/config/config.go | 6 +- internal/controller/oidc_controller_test.go | 7 +- internal/controller/proxy_controller_test.go | 9 +- internal/controller/user_controller_test.go | 9 +- .../controller/well_known_controller_test.go | 7 +- internal/repository/memory/oidc_queries.go | 241 ++++++++++++++++++ internal/repository/memory/session_queries.go | 63 +++++ internal/repository/memory/store.go | 27 ++ internal/repository/sqlite/store.go | 68 +++-- internal/repository/store.go | 8 +- internal/service/auth_service.go | 3 +- internal/service/oidc_service.go | 15 +- 14 files changed, 435 insertions(+), 69 deletions(-) create mode 100644 internal/repository/memory/oidc_queries.go create mode 100644 internal/repository/memory/session_queries.go create mode 100644 internal/repository/memory/store.go diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go index e66ae8ee..d6cb6318 100644 --- a/cmd/gen/sqlc-wrapper/main.go +++ b/cmd/gen/sqlc-wrapper/main.go @@ -449,18 +449,18 @@ func buildBody(m methodInfo) string { // no repo-typed result → direct return if len(m.Results) == 0 || m.Results[0].RepoType == "" { - return "\treturn " + call + "\n" + return "\treturn mapErr(" + call + ")\n" } r := m.Results[0] if r.IsSlice { return fmt.Sprintf( - "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } return fmt.Sprintf( - "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } @@ -477,6 +477,8 @@ package {{.PkgName}} import ( "context" + "database/sql" + "errors" "{{.RepoPkg}}" ) @@ -491,6 +493,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + {{range .ModelTypes -}} func {{converterFn .}}(v {{.}}) repository.{{.}} { return repository.{{.}}(v) diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 2279cb23..4f09372a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -8,6 +8,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" @@ -17,14 +18,14 @@ import ( ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { - return app.setupSQLite(app.config.Database.Path) -} - -// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store. -// Useful for testing or when constructing a store outside of a BootstrapApp. -func NewSQLiteStore(databasePath string) (repository.Store, error) { - app := &BootstrapApp{} - return app.setupSQLite(databasePath) + switch app.config.Database.Driver { + case "memory": + return memory.New(), nil + case "sqlite", "": + return app.setupSQLite(app.config.Database.Path) + default: + return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) + } } func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { diff --git a/internal/config/config.go b/internal/config/config.go index 5b14e27e..9d2a8663 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,7 +4,8 @@ package config func NewDefaultConfiguration() *Config { return &Config{ Database: DatabaseConfig{ - Path: "./tinyauth.db", + Driver: "sqlite", + Path: "./tinyauth.db", }, Analytics: AnalyticsConfig{ Enabled: true, @@ -95,7 +96,8 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the SQLite database, including file name." yaml:"path"` + Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` + Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 991f6759..b83094c1 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -12,9 +12,9 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -847,11 +847,10 @@ func TestOIDCController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index adfc7fb1..74bfdead 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -2,13 +2,12 @@ package controller_test import ( "net/http/httptest" - "path" "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -17,7 +16,6 @@ import ( func TestProxyController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ @@ -392,11 +390,10 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index b67c70fa..1d6e11b2 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -3,16 +3,15 @@ package controller_test import ( "encoding/json" "net/http/httptest" - "path" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -21,7 +20,6 @@ import ( func TestUserController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ @@ -350,11 +348,10 @@ func TestUserController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index eba449b0..25c8e5a8 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -8,9 +8,9 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -100,11 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go new file mode 100644 index 00000000..80305fc0 --- /dev/null +++ b/internal/repository/memory/oidc_queries.go @@ -0,0 +1,241 @@ +package memory + +import ( + "context" + "fmt" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, c := range s.oidcCodes { + if c.Sub == arg.Sub { + return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + } + } + code := repository.OidcCode(arg) + s.oidcCodes[arg.CodeHash] = code + return code, nil +} + +// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + delete(s.oidcCodes, codeHash) + return c, nil +} + +// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + return c, nil +} + +// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, c := range s.oidcCodes { + if c.Sub == sub { + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcCodes, codeHash) + return nil +} + +func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcCode + for k, c := range s.oidcCodes { + if c.ExpiresAt < expiresAt { + deleted = append(deleted, c) + delete(s.oidcCodes, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, t := range s.oidcTokens { + if t.Sub == arg.Sub { + return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") + } + } + tok := repository.OidcToken{ + Sub: arg.Sub, + AccessTokenHash: arg.AccessTokenHash, + RefreshTokenHash: arg.RefreshTokenHash, + CodeHash: arg.CodeHash, + Scope: arg.Scope, + ClientID: arg.ClientID, + TokenExpiresAt: arg.TokenExpiresAt, + RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, + Nonce: arg.Nonce, + } + s.oidcTokens[arg.AccessTokenHash] = tok + return tok, nil +} + +func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + t, ok := s.oidcTokens[accessTokenHash] + if !ok { + return repository.OidcToken{}, repository.ErrNotFound + } + return t, nil +} + +func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.RefreshTokenHash == refreshTokenHash { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.Sub == sub { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.RefreshTokenHash == arg.RefreshTokenHash_2 { + delete(s.oidcTokens, k) + t.AccessTokenHash = arg.AccessTokenHash + t.RefreshTokenHash = arg.RefreshTokenHash + t.TokenExpiresAt = arg.TokenExpiresAt + t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + s.oidcTokens[arg.AccessTokenHash] = t + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcTokens, accessTokenHash) + return nil +} + +func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.Sub == sub { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.CodeHash == codeHash { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcToken + for k, t := range s.oidcTokens { + if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + deleted = append(deleted, t) + delete(s.oidcTokens, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + u := repository.OidcUserinfo(arg) + s.oidcUsers[arg.Sub] = u + return u, nil +} + +func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + u, ok := s.oidcUsers[sub] + if !ok { + return repository.OidcUserinfo{}, repository.ErrNotFound + } + return u, nil +} + +func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcUsers, sub) + return nil +} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go new file mode 100644 index 00000000..2edde6b1 --- /dev/null +++ b/internal/repository/memory/session_queries.go @@ -0,0 +1,63 @@ +package memory + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess := repository.Session(arg) + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[uuid] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + return sess, nil +} + +func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.sessions[arg.UUID] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + sess.Username = arg.Username + sess.Email = arg.Email + sess.Name = arg.Name + sess.Provider = arg.Provider + sess.TotpPending = arg.TotpPending + sess.OAuthGroups = arg.OAuthGroups + sess.Expiry = arg.Expiry + sess.OAuthName = arg.OAuthName + sess.OAuthSub = arg.OAuthSub + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) DeleteSession(_ context.Context, uuid string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, uuid) + return nil +} + +func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.sessions { + if v.Expiry < expiry { + delete(s.sessions, k) + } + } + return nil +} diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go new file mode 100644 index 00000000..969cba66 --- /dev/null +++ b/internal/repository/memory/store.go @@ -0,0 +1,27 @@ +// Package memory provides an in-memory implementation of repository.Store for use in tests. +package memory + +import ( + "sync" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store is a thread-safe in-memory implementation of repository.Store. +type Store struct { + mu sync.RWMutex + sessions map[string]repository.Session + oidcCodes map[string]repository.OidcCode + oidcTokens map[string]repository.OidcToken + oidcUsers map[string]repository.OidcUserinfo +} + +// New returns a new empty in-memory Store. +func New() repository.Store { + return &Store{ + sessions: make(map[string]repository.Session), + oidcCodes: make(map[string]repository.OidcCode), + oidcTokens: make(map[string]repository.OidcToken), + oidcUsers: make(map[string]repository.OidcUserinfo), + } +} diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index 65b4e190..f316efa4 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -3,6 +3,8 @@ package sqlite import ( "context" + "database/sql" + "errors" "github.com/tinyauthapp/tinyauth/internal/repository" ) @@ -17,6 +19,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + func oidcCodeToRepo(v OidcCode) repository.OidcCode { return repository.OidcCode(v) } @@ -32,7 +50,7 @@ func sessionToRepo(v Session) repository.Session { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -40,7 +58,7 @@ func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCod func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -48,7 +66,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -56,7 +74,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -64,7 +82,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcCode, len(rows)) for i, row := range rows { @@ -76,7 +94,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([] func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcToken, len(rows)) for i, row := range rows { @@ -86,41 +104,41 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { - return s.q.DeleteExpiredSessions(ctx, expiry) + return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcCode(ctx, codeHash) + return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) } func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcCodeBySub(ctx, sub) + return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) } func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return s.q.DeleteOidcToken(ctx, accessTokenHash) + return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) } func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) + return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) } func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcTokenBySub(ctx, sub) + return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) } func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return s.q.DeleteOidcUserInfo(ctx, sub) + return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { - return s.q.DeleteSession(ctx, uuid) + return mapErr(s.q.DeleteSession(ctx, uuid)) } func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCode(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -128,7 +146,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -136,7 +154,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -144,7 +162,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -152,7 +170,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcToken(ctx, accessTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -160,7 +178,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -168,7 +186,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenBySub(ctx, sub) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -176,7 +194,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { r, err := s.q.GetOidcUserInfo(ctx, sub) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -184,7 +202,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { r, err := s.q.GetSession(ctx, uuid) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -192,7 +210,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -200,7 +218,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } diff --git a/internal/repository/store.go b/internal/repository/store.go index 765df6a5..302f2f10 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -1,6 +1,12 @@ package repository -import "context" +import ( + "context" + "errors" +) + +// ErrNotFound is returned by Store methods when the requested record does not exist. +var ErrNotFound = errors.New("not found") // Store is the interface that all storage drivers must implement. // The sqlc-generated *Queries struct satisfies this interface for SQLite. diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ab343396..5d2ead2f 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "database/sql" "errors" "fmt" "regexp" @@ -411,7 +410,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e session, err := auth.queries.GetSession(c, cookie) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.Session{}, fmt.Errorf("session not found") } return repository.Session{}, err diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index e5f7ea76..14d94f61 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -7,7 +7,6 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" - "database/sql" "encoding/base64" "encoding/json" "encoding/pem" @@ -420,7 +419,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err @@ -564,7 +563,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err @@ -643,7 +642,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err @@ -731,15 +730,15 @@ func (service *OIDCService) Hash(token string) string { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcUserInfo(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil @@ -784,7 +783,7 @@ func (service *OIDCService) Cleanup() { token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { continue } tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")