From 17b9b8b4a3ea94a239ec6fe3872006e6038365c1 Mon Sep 17 00:00:00 2001 From: sugghosh Date: Tue, 1 Jul 2025 22:28:11 +0530 Subject: [PATCH] logout when restart the server --- .../ha/config/FormAuthConfiguration.java | 15 +++ .../gateway/ha/resource/LoginResource.java | 14 +++ .../ha/security/LbFormAuthManager.java | 36 +++++++ .../gateway/ha/security/SessionCookie.java | 2 +- webapp/src/App.tsx | 13 +++ webapp/src/api/base.ts | 34 +++++- webapp/src/api/webapp/login.ts | 4 + webapp/src/store/access.ts | 50 ++++++++- webapp/src/utils/session.ts | 102 ++++++++++++++++++ 9 files changed, 267 insertions(+), 3 deletions(-) create mode 100644 webapp/src/utils/session.ts diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/FormAuthConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/FormAuthConfiguration.java index cd5e29ebc..62d3ed9c9 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/FormAuthConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/FormAuthConfiguration.java @@ -13,10 +13,15 @@ */ package io.trino.gateway.ha.config; +import io.airlift.units.Duration; + +import java.util.concurrent.TimeUnit; + public class FormAuthConfiguration { private SelfSignKeyPairConfiguration selfSignKeyPair; private String ldapConfigPath; + private Duration sessionTimeout = new Duration(30, TimeUnit.MINUTES); public FormAuthConfiguration(SelfSignKeyPairConfiguration selfSignKeyPair, String ldapConfigPath) { @@ -45,4 +50,14 @@ public void setLdapConfigPath(String ldapConfigPath) { this.ldapConfigPath = ldapConfigPath; } + + public Duration getSessionTimeout() + { + return this.sessionTimeout; + } + + public void setSessionTimeout(Duration sessionTimeout) + { + this.sessionTimeout = sessionTimeout; + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java b/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java index a5a89cbd4..a14480487 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java @@ -185,4 +185,18 @@ else if (oauthManager != null) { } return Response.ok(Result.ok("Ok", loginType)).build(); } + + @POST + @Path("serverInfo") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public Response serverInfo() + { + long serverStartTime = System.currentTimeMillis(); + if (formAuthManager != null) { + serverStartTime = formAuthManager.getServerStartTime(); + } + Map serverInfo = Map.of("serverStart", serverStartTime); + return Response.ok(Result.ok(serverInfo)).build(); + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java index 34b5a89ec..f889fe5ff 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java @@ -19,6 +19,7 @@ import com.auth0.jwt.interfaces.Claim; import com.auth0.jwt.interfaces.DecodedJWT; import io.airlift.log.Logger; +import io.airlift.units.Duration; import io.trino.gateway.ha.config.FormAuthConfiguration; import io.trino.gateway.ha.config.LdapConfiguration; import io.trino.gateway.ha.config.UserConfiguration; @@ -26,10 +27,13 @@ import io.trino.gateway.ha.domain.request.RestLoginRequest; import io.trino.gateway.ha.security.util.BasicCredentials; +import java.time.Instant; import java.util.Collections; +import java.util.Date; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.TimeUnit; import java.util.stream.Stream; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -38,6 +42,8 @@ public class LbFormAuthManager { private static final Logger log = Logger.get(LbFormAuthManager.class); + private static final long SERVER_START_TIME = System.currentTimeMillis(); + private static final Duration DEFAULT_SESSION_TIMEOUT = new Duration(30, TimeUnit.MINUTES); /** * Cookie key to pass the token. */ @@ -45,6 +51,7 @@ public class LbFormAuthManager private final Map presetUsers; private final Map pagePermissions; private final LbLdapClient lbLdapClient; + private final Duration sessionTimeout; public LbFormAuthManager(FormAuthConfiguration configuration, Map presetUsers, @@ -58,9 +65,12 @@ public LbFormAuthManager(FormAuthConfiguration configuration, if (configuration != null) { this.lbKeyProvider = new LbKeyProvider(configuration .getSelfSignKeyPair()); + this.sessionTimeout = configuration.getSessionTimeout() != null ? + configuration.getSessionTimeout() : DEFAULT_SESSION_TIMEOUT; } else { this.lbKeyProvider = null; + this.sessionTimeout = DEFAULT_SESSION_TIMEOUT; } if (configuration != null && configuration.getLdapConfigPath() != null) { @@ -105,6 +115,21 @@ public Optional> getClaimsFromIdToken(String idToken) DecodedJWT jwt = JWT.decode(idToken); if (LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer(), Optional.empty())) { + // Check if token was issued before server restart + Optional serverStartClaim = Optional.ofNullable(jwt.getClaim("server_start")); + if (serverStartClaim.isPresent() && !serverStartClaim.orElseThrow().isNull()) { + long tokenServerStart = serverStartClaim.orElseThrow().asLong(); + if (tokenServerStart != SERVER_START_TIME) { + log.info("Token invalidated due to server restart"); + return Optional.empty(); + } + } + // Check token expiration + Optional expiresAt = Optional.ofNullable(jwt.getExpiresAt()); + if (expiresAt.isPresent() && expiresAt.orElseThrow().before(new Date())) { + log.info("Token expired"); + return Optional.empty(); + } return Optional.of(jwt.getClaims()); } } @@ -124,10 +149,16 @@ private String getSelfSignedToken(String username) Map headers = Map.of("alg", "RS256"); + Instant now = Instant.now(); + Instant expiration = now.plusSeconds(sessionTimeout.roundTo(TimeUnit.SECONDS)); + token = JWT.create() .withHeader(headers) .withIssuer(SessionCookie.SELF_ISSUER_ID) .withSubject(username) + .withIssuedAt(Date.from(now)) + .withExpiresAt(Date.from(expiration)) + .withClaim("server_start", SERVER_START_TIME) .sign(algorithm); } catch (JWTCreationException exception) { @@ -167,4 +198,9 @@ public List processPagePermissions(List roles) .flatMap(role -> Stream.of(pagePermissions.get(role).split("_"))) .distinct().toList(); } + + public long getServerStartTime() + { + return SERVER_START_TIME; + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/SessionCookie.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/SessionCookie.java index b990947bb..b32d273b1 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/SessionCookie.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/SessionCookie.java @@ -30,7 +30,7 @@ public static NewCookie getTokenCookie(String token) .path("/") .domain("") .comment("") - .maxAge(60 * 60 * 24) + .maxAge(60 * 15) // 15 minutes session timeout .secure(true) .build(); } diff --git a/webapp/src/App.tsx b/webapp/src/App.tsx index 283669b21..db8f62fff 100644 --- a/webapp/src/App.tsx +++ b/webapp/src/App.tsx @@ -16,6 +16,7 @@ import { useEffect } from 'react'; import { getCSSVar } from './utils/utils'; import { IllustrationIdle, IllustrationIdleDark } from '@douyinfe/semi-illustrations'; import Cookies from 'js-cookie'; +import { SessionManager } from './utils/session'; function App() { return ( @@ -40,6 +41,18 @@ function Screen() { access.updateToken(token); Cookies.remove('token'); } + // Initialize session management + const sessionManager = SessionManager.getInstance(); + sessionManager.setSessionExpiredCallback(() => { + access.logout(); + }); + + // Check token validity on app start + access.checkTokenValidity().catch(console.error); + + return () => { + sessionManager.clearTimeout(); + }; }, []) return ( <> diff --git a/webapp/src/api/base.ts b/webapp/src/api/base.ts index 47654b782..56ffe407b 100644 --- a/webapp/src/api/base.ts +++ b/webapp/src/api/base.ts @@ -1,9 +1,12 @@ import { useAccessStore } from "../store"; import Locale, { getServerLang } from "../locales"; import { Toast } from "@douyinfe/semi-ui"; +import { SessionManager } from "../utils/session"; export class ClientApi { async get(url: string, params: Record = {}): Promise { + // Check token validity before making request + await this.validateTokenBeforeRequest(url); let queryString = ""; if (Object.keys(params).length > 0) { queryString = "?" + new URLSearchParams(params).toString(); @@ -40,6 +43,8 @@ export class ClientApi { } async post(url: string, body: Record = {}): Promise { + // Check token validity before making request + await this.validateTokenBeforeRequest(url); const res: Response = await fetch( this.path(url), { @@ -76,6 +81,8 @@ export class ClientApi { } async postForm(url: string, formData: FormData = new FormData()): Promise { + // Check token validity before making request + await this.validateTokenBeforeRequest(url); const res: Response = await fetch( this.path(url), { @@ -104,6 +111,26 @@ export class ClientApi { return resJson.data; } + private async validateTokenBeforeRequest(url: string): Promise { + // Skip validation for login-related endpoints to avoid infinite loops + if (url.includes('/login') || url.includes('/serverInfo') || url.includes('/loginType')) { + return; + } + + const accessStore = useAccessStore.getState(); + if (accessStore.token) { + try { + const isValid = await accessStore.checkTokenValidity(); + if (!isValid) { + throw new Error('Token validation failed'); + } + } catch (error) { + // Token validation failed, user will be logged out + throw error; + } + } + } + path(path: string): string { const proxyPath = import.meta.env.VITE_PROXY_PATH; return [proxyPath, path].join(""); @@ -134,7 +161,12 @@ export function getHeaders(): Record { const validString = (x: string) => x && x.length > 0; if (validString(accessStore.token)) { - headers.Authorization = makeBearer(accessStore.token); + // For synchronous header generation, we'll do basic token validation + // The async server restart check will happen in the session manager + const sessionManager = SessionManager.getInstance(); + if (!sessionManager.isTokenExpired(accessStore.token)) { + headers.Authorization = makeBearer(accessStore.token); + } } return headers; diff --git a/webapp/src/api/webapp/login.ts b/webapp/src/api/webapp/login.ts index da5396021..5b2de5558 100644 --- a/webapp/src/api/webapp/login.ts +++ b/webapp/src/api/webapp/login.ts @@ -23,3 +23,7 @@ export async function loginTypeApi(): Promise { export async function getUIConfiguration(): Promise { return api.get('/webapp/getUIConfiguration') } + +export async function serverInfoApi(): Promise { + return api.post('/serverInfo', {}) +} diff --git a/webapp/src/store/access.ts b/webapp/src/store/access.ts index cf8fe1d69..08aceb2f7 100644 --- a/webapp/src/store/access.ts +++ b/webapp/src/store/access.ts @@ -1,7 +1,8 @@ import { create } from "zustand"; import { persist } from "zustand/middleware"; import { StoreKey } from "../constant"; -import { getInfoApi } from "../api/webapp/login"; +import { getInfoApi, serverInfoApi } from "../api/webapp/login"; +import { SessionManager } from "../utils/session"; export enum Role { ADMIN = "ADMIN", @@ -28,6 +29,8 @@ export interface AccessControlStore { getUserInfo: (_?: boolean) => void; hasRole: (role: Role) => boolean; hasPermission: (permission: string | undefined) => boolean; + logout: () => void; + checkTokenValidity: () => Promise; } let fetchState: number = 0; // 0 not fetch, 1 fetching, 2 done @@ -78,6 +81,51 @@ export const useAccessStore = create()( const permissions = get().permissions return permission == undefined || permissions == null || permissions.length == 0 || permissions.includes(permission); }, + logout() { + const sessionManager = SessionManager.getInstance(); + sessionManager.clearTimeout(); + set(() => ({ + token: "", + userId: "", + userName: "", + nickName: "", + userType: "", + email: "", + phonenumber: "", + sex: "", + avatar: "", + permissions: [], + roles: [], + })); + fetchState = 0; + }, + async checkTokenValidity() { + const token = get().token; + if (!token) return false; + + const sessionManager = SessionManager.getInstance(); + + // Check if token is expired + if (sessionManager.isTokenExpired(token)) { + get().logout(); + return false; + } + + // Check for server restart + try { + const serverInfo = await serverInfoApi(); + if (sessionManager.checkServerRestart(token, serverInfo.serverStart)) { + console.log('Server restart detected, logging out'); + get().logout(); + return false; + } + } catch (error) { + console.error('Error checking server info:', error); + // Don't logout on API error, just continue + } + + return true; + }, }), { name: StoreKey.Access, diff --git a/webapp/src/utils/session.ts b/webapp/src/utils/session.ts new file mode 100644 index 000000000..d43db6266 --- /dev/null +++ b/webapp/src/utils/session.ts @@ -0,0 +1,102 @@ +export class SessionManager { + private static instance: SessionManager; + private timeoutId: number | null = null; + private readonly TIMEOUT_MINUTES = 15; + private readonly CHECK_INTERVAL = 60000; // Check every minute + private lastActivity: number = Date.now(); + private onSessionExpired?: () => void; + + private constructor() { + this.setupActivityListeners(); + this.startTimeoutCheck(); + } + + public static getInstance(): SessionManager { + if (!SessionManager.instance) { + SessionManager.instance = new SessionManager(); + } + return SessionManager.instance; + } + + public setSessionExpiredCallback(callback: () => void): void { + this.onSessionExpired = callback; + } + + public resetTimeout(): void { + this.lastActivity = Date.now(); + } + + public clearTimeout(): void { + if (this.timeoutId) { + clearInterval(this.timeoutId); + this.timeoutId = null; + } + } + + private setupActivityListeners(): void { + const events = ['mousedown', 'mousemove', 'keypress', 'scroll', 'touchstart', 'click']; + + events.forEach(event => { + document.addEventListener(event, () => { + this.resetTimeout(); + }, true); + }); + } + + private startTimeoutCheck(): void { + this.timeoutId = setInterval(() => { + const now = Date.now(); + const timeSinceLastActivity = now - this.lastActivity; + const timeoutMs = this.TIMEOUT_MINUTES * 60 * 1000; + + if (timeSinceLastActivity >= timeoutMs) { + this.handleSessionExpired(); + } + }, this.CHECK_INTERVAL); + } + + private handleSessionExpired(): void { + this.clearTimeout(); + if (this.onSessionExpired) { + this.onSessionExpired(); + } + } + + public isTokenExpired(token: string): boolean { + if (!token) return true; + + try { + // Decode JWT token to check expiration + const payload = JSON.parse(atob(token.split('.')[1])); + const currentTime = Math.floor(Date.now() / 1000); + + // Check if token has expired + if (payload.exp && payload.exp < currentTime) { + return true; + } + + return false; + } catch (error) { + console.error('Error decoding token:', error); + return true; + } + } + + public checkServerRestart(token: string, currentServerStart: number): boolean { + if (!token) return false; + + try { + const payload = JSON.parse(atob(token.split('.')[1])); + + // Check if token was issued before server restart + if (payload.server_start && payload.server_start !== currentServerStart) { + return true; + } + + return false; + } catch (error) { + console.error('Error checking server restart:', error); + return false; + } + } +}