diff --git a/pkg/api/api.go b/pkg/api/api.go index 5e4d1c2f0..7f94005b2 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -26,6 +26,9 @@ var ( // DbClient represents the active database connection in a single-session mode DbClient *client.Client + // DbSession represents the active database session in single-session mode + DbSession *Session + // DbSessions represents the mapping for client connections DbSessions *SessionManager @@ -33,23 +36,31 @@ var ( QueryStore *queries.Store ) +// DBSession returns a database session from the client context +func DBSession(c *gin.Context) *Session { + if command.Opts.Sessions { + return DbSessions.GetSession(getSessionId(c.Request)) + } + return DbSession +} + // DB returns a database connection from the client context func DB(c *gin.Context) *client.Client { - if command.Opts.Sessions { - return DbSessions.Get(getSessionId(c.Request)) + session := DBSession(c) + if session == nil { + return nil } - return DbClient + return session.Client } -// setClient sets the database client connection for the sessions -func setClient(c *gin.Context, newClient *client.Client) error { +func setSession(c *gin.Context, newSession *Session) error { currentClient := DB(c) if currentClient != nil { currentClient.Close() } if !command.Opts.Sessions { - DbClient = newClient + DbSession = newSession return nil } @@ -58,7 +69,7 @@ func setClient(c *gin.Context, newClient *client.Client) error { return errSessionRequired } - DbSessions.Add(sid, newClient) + DbSessions.AddSession(sid, newSession) return nil } @@ -99,16 +110,6 @@ func ConnectWithBackend(c *gin.Context) { PassHeaders: strings.Split(command.Opts.ConnectHeaders, ","), } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - // Fetch connection credentials - cred, err := backend.FetchCredential(ctx, c.Param("resource"), c) - if err != nil { - badRequest(c, err) - return - } - // Make the new session sid, err := securerandom.Uuid() if err != nil { @@ -117,21 +118,44 @@ func ConnectWithBackend(c *gin.Context) { } c.Request.Header.Add("x-session-id", sid) - // Connect to the database - cl, err := client.NewFromUrl(cred.DatabaseURL, nil) - if err != nil { - badRequest(c, err) - return - } - cl.External = true - - // Finalize session seetup - _, err = cl.Info() - if err == nil { - err = setClient(c, cl) - } + err = setSession(c, &Session{ + SessionRefresh: func(s *Session) (*Session, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + if s == nil { + s = &Session{} + } + if s.Client != nil { + s.Client.Close() + } + + // Fetch connection credentials + cred, err := backend.FetchCredential(ctx, c.Param("resource"), c) + if err != nil { + return s, err + } + + s.SessionExpiry = cred.SessionExpiry + + // Connect to the database + s.Client, err = client.NewFromUrl(cred.DatabaseURL, nil) + if err != nil { + return s, err + } + s.Client.External = true + + // Finalize session setup + _, err = s.Client.Info() + if err != nil { + s.Client.Close() + return s, err + } + + return s, nil + }, + }) if err != nil { - cl.Close() badRequest(c, err) return } @@ -172,7 +196,9 @@ func Connect(c *gin.Context) { info, err := cl.Info() if err == nil { - err = setClient(c, cl) + err = setSession(c, &Session{ + Client: cl, + }) } if err != nil { cl.Close() @@ -265,7 +291,9 @@ func SwitchDb(c *gin.Context) { info, err := cl.Info() if err == nil { - err = setClient(c, cl) + err = setSession(c, &Session{ + Client: cl, + }) } if err != nil { cl.Close() @@ -291,18 +319,19 @@ func Disconnect(c *gin.Context) { return } - conn := DB(c) - if conn == nil { + session := DBSession(c) + if session == nil { badRequest(c, errNotConnected) return } - err := conn.Close() + err := session.Client.Close() if err != nil { badRequest(c, err) return } + DbSession = nil DbClient = nil successResponse(c, gin.H{"success": true}) } diff --git a/pkg/api/backend.go b/pkg/api/backend.go index 906dd06a2..3afaf64e8 100644 --- a/pkg/api/backend.go +++ b/pkg/api/backend.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/gin-gonic/gin" ) @@ -27,7 +28,8 @@ type BackendRequest struct { // BackendCredential represents the third-party response type BackendCredential struct { - DatabaseURL string `json:"database_url"` + DatabaseURL string `json:"database_url"` + SessionExpiry time.Time `json:"session_expiry"` } // FetchCredential sends an authentication request to a third-party service diff --git a/pkg/api/session_manager.go b/pkg/api/session_manager.go index dbd1749f4..3fcfe3e61 100644 --- a/pkg/api/session_manager.go +++ b/pkg/api/session_manager.go @@ -10,9 +10,15 @@ import ( "github.com/sosedoff/pgweb/pkg/metrics" ) +type Session struct { + Client *client.Client + SessionExpiry time.Time + SessionRefresh func(*Session) (*Session, error) +} + type SessionManager struct { logger *logrus.Logger - sessions map[string]*client.Client + sessions map[string]*Session mu sync.Mutex idleTimeout time.Duration } @@ -20,7 +26,7 @@ type SessionManager struct { func NewSessionManager(logger *logrus.Logger) *SessionManager { return &SessionManager{ logger: logger, - sessions: map[string]*client.Client{}, + sessions: map[string]*Session{}, mu: sync.Mutex{}, } } @@ -46,10 +52,15 @@ func (m *SessionManager) Sessions() map[string]*client.Client { sessions := m.sessions m.mu.Unlock() - return sessions + mapOfClients := map[string]*client.Client{} + for k, v := range sessions { + mapOfClients[k] = v.Client + } + + return mapOfClients } -func (m *SessionManager) Get(id string) *client.Client { +func (m *SessionManager) GetSession(id string) *Session { m.mu.Lock() c := m.sessions[id] m.mu.Unlock() @@ -57,21 +68,48 @@ func (m *SessionManager) Get(id string) *client.Client { return c } +func (m *SessionManager) Get(id string) *client.Client { + m.mu.Lock() + c := m.sessions[id] + m.mu.Unlock() + + if c == nil { + return nil + } + + return c.Client +} + func (m *SessionManager) Add(id string, conn *client.Client) { + m.AddSession(id, &Session{ + Client: conn, + }) +} + +func (m *SessionManager) AddSession(id string, session *Session) error { m.mu.Lock() defer m.mu.Unlock() - m.sessions[id] = conn + if session.Client == nil && session.SessionRefresh != nil { + var err error + session, err = session.SessionRefresh(session) + if err != nil { + return err + } + } + m.sessions[id] = session + metrics.SetSessionsCount(len(m.sessions)) + return nil } func (m *SessionManager) Remove(id string) bool { m.mu.Lock() defer m.mu.Unlock() - conn, ok := m.sessions[id] + session, ok := m.sessions[id] if ok { - conn.Close() + session.Client.Close() delete(m.sessions, id) } @@ -79,6 +117,37 @@ func (m *SessionManager) Remove(id string) bool { return ok } +func (m *SessionManager) RefreshSession(id string) error { + m.mu.Lock() + defer m.mu.Unlock() + + session, ok := m.sessions[id] + if !ok { + // session not found + return nil + } + + if session.SessionRefresh == nil || session.SessionExpiry.IsZero() { + // ClientFactory or SessionExpiry is not set so it is impossible to refresh + // the session + return nil + } + + if session.SessionExpiry.After(time.Now()) { + // session has not expired yet + return nil + } + + session, err := session.SessionRefresh(session) + if err != nil { + return err + } + + m.sessions[id] = session + + return nil +} + func (m *SessionManager) Len() int { m.mu.Lock() sz := len(m.sessions) @@ -118,17 +187,42 @@ func (m *SessionManager) RunPeriodicCleanup() { } func (m *SessionManager) staleSessions() []string { - m.mu.TryLock() + m.mu.Lock() defer m.mu.Unlock() now := time.Now() ids := []string{} - for id, conn := range m.sessions { - if now.Sub(conn.LastQueryTime()) > m.idleTimeout { + for id, session := range m.sessions { + if now.Sub(session.Client.LastQueryTime()) > m.idleTimeout { ids = append(ids, id) } } return ids } + +func (m *SessionManager) RefreshSessions() error { + m.mu.Lock() + sessions := m.sessions + m.mu.Unlock() + + for id := range sessions { + if err := m.RefreshSession(id); err != nil { + return err + } + } + + return nil +} + +func (m *SessionManager) RunPeriodicRefresh() { + m.logger.Info("session manager refresh enabled") + + for range time.Tick(time.Minute) { + if err := m.RefreshSessions(); err != nil { + // TODO: better error handling and logging + m.logger.Error("error refreshing sessions:", err) + } + } +} diff --git a/pkg/api/session_manager_test.go b/pkg/api/session_manager_test.go index 594af3635..0628a6e7c 100644 --- a/pkg/api/session_manager_test.go +++ b/pkg/api/session_manager_test.go @@ -16,8 +16,8 @@ func TestSessionManager(t *testing.T) { manager := NewSessionManager(nil) assert.Equal(t, []string{}, manager.IDs()) - manager.sessions["foo"] = &client.Client{} - manager.sessions["bar"] = &client.Client{} + manager.sessions["foo"] = &Session{Client: &client.Client{}} + manager.sessions["bar"] = &Session{Client: &client.Client{}} ids := manager.IDs() sort.Strings(ids) @@ -28,7 +28,7 @@ func TestSessionManager(t *testing.T) { manager := NewSessionManager(nil) assert.Nil(t, manager.Get("foo")) - manager.sessions["foo"] = &client.Client{} + manager.sessions["foo"] = &Session{Client: &client.Client{}} assert.NotNil(t, manager.Get("foo")) }) @@ -53,8 +53,8 @@ func TestSessionManager(t *testing.T) { t.Run("return len", func(t *testing.T) { manager := NewSessionManager(nil) - manager.sessions["foo"] = &client.Client{} - manager.sessions["bar"] = &client.Client{} + manager.sessions["foo"] = &Session{Client: &client.Client{}} + manager.sessions["bar"] = &Session{Client: &client.Client{}} assert.Equal(t, 2, manager.Len()) }) diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 9ac45c57a..5c3035b1f 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -99,6 +99,9 @@ func initClient() { exitWithMessage(err.Error()) } + api.DbSession = &api.Session{ + Client: cl, + } api.DbClient = cl } @@ -298,8 +301,8 @@ func Run() { initOptions() initClient() - if api.DbClient != nil { - defer api.DbClient.Close() + if api.DbSession != nil { + defer api.DbSession.Client.Close() } if !options.Debug { @@ -321,6 +324,31 @@ func Run() { } } + // TODO: make refresh optional + if options.Sessions { + go api.DbSessions.RunPeriodicRefresh() + } else { + go func() { + for { + if api.DbSession.Client == nil { + continue + } + if api.DbSession.SessionRefresh == nil || api.DbSession.SessionExpiry.IsZero() { + continue + } + newSession, err := api.DbSession.SessionRefresh(api.DbSession) + if err != nil { + // TODO: better error handling and logging + logger.Error("error refreshing sessions:", err) + continue + } + // FIXME: potential data race here + api.DbSession = newSession + api.DbClient = api.DbSession.Client + } + }() + } + // Start a separate metrics http server. If metrics addr is not provided, we // add the metrics endpoint in the existing application server (see api.go). if options.MetricsEnabled && options.MetricsAddr != "" { diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 7e84e0c09..146e6b857 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -106,6 +106,7 @@ func setup() { out, err := exec.Command( testCommands["createdb"], + "-w", "-U", serverUser, "-h", serverHost, "-p", serverPort, @@ -120,6 +121,7 @@ func setup() { out, err = exec.Command( testCommands["psql"], + "-w", "-U", serverUser, "-h", serverHost, "-p", serverPort, @@ -148,6 +150,7 @@ func teardownClient() { func teardown(t *testing.T, allowFail bool) { output, err := exec.Command( testCommands["dropdb"], + "-w", "-U", serverUser, "-h", serverHost, "-p", serverPort,