Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 65 additions & 36 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,41 @@ var (
// DbClient represents the active database connection in a single-session mode
DbClient *client.Client

// DbSession represents the active database session in single-session mode
DbSession *Session

// DbSessions represents the mapping for client connections
DbSessions *SessionManager

// QueryStore reads the SQL queries stores in the home directory
QueryStore *queries.Store
)

// DBSession returns a database session from the client context
func DBSession(c *gin.Context) *Session {
if command.Opts.Sessions {
return DbSessions.GetSession(getSessionId(c.Request))
}
return DbSession
}

// DB returns a database connection from the client context
func DB(c *gin.Context) *client.Client {
if command.Opts.Sessions {
return DbSessions.Get(getSessionId(c.Request))
session := DBSession(c)
if session == nil {
return nil
}
return DbClient
return session.Client
}

// setClient sets the database client connection for the sessions
func setClient(c *gin.Context, newClient *client.Client) error {
func setSession(c *gin.Context, newSession *Session) error {
currentClient := DB(c)
if currentClient != nil {
currentClient.Close()
}

if !command.Opts.Sessions {
DbClient = newClient
DbSession = newSession
return nil
}

Expand All @@ -58,7 +69,7 @@ func setClient(c *gin.Context, newClient *client.Client) error {
return errSessionRequired
}

DbSessions.Add(sid, newClient)
DbSessions.AddSession(sid, newSession)
return nil
}

Expand Down Expand Up @@ -99,16 +110,6 @@ func ConnectWithBackend(c *gin.Context) {
PassHeaders: strings.Split(command.Opts.ConnectHeaders, ","),
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

// Fetch connection credentials
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c)
if err != nil {
badRequest(c, err)
return
}

// Make the new session
sid, err := securerandom.Uuid()
if err != nil {
Expand All @@ -117,21 +118,44 @@ func ConnectWithBackend(c *gin.Context) {
}
c.Request.Header.Add("x-session-id", sid)

// Connect to the database
cl, err := client.NewFromUrl(cred.DatabaseURL, nil)
if err != nil {
badRequest(c, err)
return
}
cl.External = true

// Finalize session seetup
_, err = cl.Info()
if err == nil {
err = setClient(c, cl)
}
err = setSession(c, &Session{
SessionRefresh: func(s *Session) (*Session, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

if s == nil {
s = &Session{}
}
if s.Client != nil {
s.Client.Close()
}

// Fetch connection credentials
cred, err := backend.FetchCredential(ctx, c.Param("resource"), c)
if err != nil {
return s, err
}

s.SessionExpiry = cred.SessionExpiry

// Connect to the database
s.Client, err = client.NewFromUrl(cred.DatabaseURL, nil)
if err != nil {
return s, err
}
s.Client.External = true

// Finalize session setup
_, err = s.Client.Info()
if err != nil {
s.Client.Close()
return s, err
}

return s, nil
},
})
if err != nil {
cl.Close()
badRequest(c, err)
return
}
Expand Down Expand Up @@ -172,7 +196,9 @@ func Connect(c *gin.Context) {

info, err := cl.Info()
if err == nil {
err = setClient(c, cl)
err = setSession(c, &Session{
Client: cl,
})
}
if err != nil {
cl.Close()
Expand Down Expand Up @@ -265,7 +291,9 @@ func SwitchDb(c *gin.Context) {

info, err := cl.Info()
if err == nil {
err = setClient(c, cl)
err = setSession(c, &Session{
Client: cl,
})
}
if err != nil {
cl.Close()
Expand All @@ -291,18 +319,19 @@ func Disconnect(c *gin.Context) {
return
}

conn := DB(c)
if conn == nil {
session := DBSession(c)
if session == nil {
badRequest(c, errNotConnected)
return
}

err := conn.Close()
err := session.Client.Close()
if err != nil {
badRequest(c, err)
return
}

DbSession = nil
DbClient = nil
successResponse(c, gin.H{"success": true})
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/api/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"strings"
"time"

"github.com/gin-gonic/gin"
)
Expand All @@ -27,7 +28,8 @@ type BackendRequest struct {

// BackendCredential represents the third-party response
type BackendCredential struct {
DatabaseURL string `json:"database_url"`
DatabaseURL string `json:"database_url"`
SessionExpiry time.Time `json:"session_expiry"`
}

// FetchCredential sends an authentication request to a third-party service
Expand Down
114 changes: 104 additions & 10 deletions pkg/api/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@ import (
"github.com/sosedoff/pgweb/pkg/metrics"
)

type Session struct {
Client *client.Client
SessionExpiry time.Time
SessionRefresh func(*Session) (*Session, error)
}

type SessionManager struct {
logger *logrus.Logger
sessions map[string]*client.Client
sessions map[string]*Session
mu sync.Mutex
idleTimeout time.Duration
}

func NewSessionManager(logger *logrus.Logger) *SessionManager {
return &SessionManager{
logger: logger,
sessions: map[string]*client.Client{},
sessions: map[string]*Session{},
mu: sync.Mutex{},
}
}
Expand All @@ -46,39 +52,102 @@ func (m *SessionManager) Sessions() map[string]*client.Client {
sessions := m.sessions
m.mu.Unlock()

return sessions
mapOfClients := map[string]*client.Client{}
for k, v := range sessions {
mapOfClients[k] = v.Client
}

return mapOfClients
}

func (m *SessionManager) Get(id string) *client.Client {
func (m *SessionManager) GetSession(id string) *Session {
m.mu.Lock()
c := m.sessions[id]
m.mu.Unlock()

return c
}

func (m *SessionManager) Get(id string) *client.Client {
m.mu.Lock()
c := m.sessions[id]
m.mu.Unlock()

if c == nil {
return nil
}

return c.Client
}

func (m *SessionManager) Add(id string, conn *client.Client) {
m.AddSession(id, &Session{
Client: conn,
})
}

func (m *SessionManager) AddSession(id string, session *Session) error {
m.mu.Lock()
defer m.mu.Unlock()

m.sessions[id] = conn
if session.Client == nil && session.SessionRefresh != nil {
var err error
session, err = session.SessionRefresh(session)
if err != nil {
return err
}
}
m.sessions[id] = session

metrics.SetSessionsCount(len(m.sessions))
return nil
}

func (m *SessionManager) Remove(id string) bool {
m.mu.Lock()
defer m.mu.Unlock()

conn, ok := m.sessions[id]
session, ok := m.sessions[id]
if ok {
conn.Close()
session.Client.Close()
delete(m.sessions, id)
}

metrics.SetSessionsCount(len(m.sessions))
return ok
}

func (m *SessionManager) RefreshSession(id string) error {
m.mu.Lock()
defer m.mu.Unlock()

session, ok := m.sessions[id]
if !ok {
// session not found
return nil
}

if session.SessionRefresh == nil || session.SessionExpiry.IsZero() {
// ClientFactory or SessionExpiry is not set so it is impossible to refresh
// the session
return nil
}

if session.SessionExpiry.After(time.Now()) {
// session has not expired yet
return nil
}

session, err := session.SessionRefresh(session)
if err != nil {
return err
}

m.sessions[id] = session

return nil
}

func (m *SessionManager) Len() int {
m.mu.Lock()
sz := len(m.sessions)
Expand Down Expand Up @@ -118,17 +187,42 @@ func (m *SessionManager) RunPeriodicCleanup() {
}

func (m *SessionManager) staleSessions() []string {
m.mu.TryLock()
m.mu.Lock()
defer m.mu.Unlock()

now := time.Now()
ids := []string{}

for id, conn := range m.sessions {
if now.Sub(conn.LastQueryTime()) > m.idleTimeout {
for id, session := range m.sessions {
if now.Sub(session.Client.LastQueryTime()) > m.idleTimeout {
ids = append(ids, id)
}
}

return ids
}

func (m *SessionManager) RefreshSessions() error {
m.mu.Lock()
sessions := m.sessions
m.mu.Unlock()

for id := range sessions {
if err := m.RefreshSession(id); err != nil {
return err
}
}

return nil
}

func (m *SessionManager) RunPeriodicRefresh() {
m.logger.Info("session manager refresh enabled")

for range time.Tick(time.Minute) {
if err := m.RefreshSessions(); err != nil {
// TODO: better error handling and logging
m.logger.Error("error refreshing sessions:", err)
}
}
}
Loading