diff --git a/src/main/java/icu/samnyan/aqua/sega/allnet/AllNetSecure.kt b/src/main/java/icu/samnyan/aqua/sega/allnet/AllNetSecure.kt index d72b7f7f..c1318a95 100644 --- a/src/main/java/icu/samnyan/aqua/sega/allnet/AllNetSecure.kt +++ b/src/main/java/icu/samnyan/aqua/sega/allnet/AllNetSecure.kt @@ -1,9 +1,11 @@ package icu.samnyan.aqua.sega.allnet +import ext.Str import jakarta.servlet.http.HttpServletRequest import jakarta.servlet.http.HttpServletRequestWrapper import jakarta.servlet.http.HttpServletResponse import org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty import org.springframework.context.annotation.Configuration import org.springframework.stereotype.Component @@ -19,6 +21,7 @@ import org.springframework.web.servlet.config.annotation.WebMvcConfigurer * This interceptor will check if the token exists in the database. */ @Component +@ConditionalOnBean(AllNetSecureInit::class) class TokenChecker( val keyChipRepo: KeyChipRepo ) : HandlerInterceptor { @@ -28,15 +31,19 @@ class TokenChecker( * Handle request before it's processed. */ override fun preHandle(req: HttpServletRequest, resp: HttpServletResponse, handler: Any): Boolean { + // Skip the interceptor if the request is already forwarded + if (req.getAttribute("token") != null) return true + + // Parse the token from the request path val token = extractTokenFromPath(req.requestURI) log.debug("PreHandle: ${req.requestURI} from ip ${req.remoteAddr}, token: $token") // Check whether the token exists in the database // The token can either be a keychip id (old method) or a session id (new method) - if (token != null && keyChipRepo.existsByKeychipId(token)) + if (token.isNotBlank() && keyChipRepo.existsByKeychipId(token)) { // Forward the request - val w = RewriteWrapper(req) + val w = RewriteWrapper(req, token).apply { setAttribute("token", token) } req.getRequestDispatcher(w.requestURI).forward(w, resp) // Prevent the request from being processed twice @@ -51,24 +58,19 @@ class TokenChecker( /** * Extract the token from the request path. - * Example: "/gs/SS12033897/mai2/SomeEndpoint" -> "12033897" + * Example: "/gs/12033897/mai2/SomeEndpoint" -> "12033897" */ - fun extractTokenFromPath(path: String) = path.split("/").find { it.startsWith("SS") }?.substring(2) + fun extractTokenFromPath(path: String) = path.substringAfter("/gs/", "").substringBefore("/", "") } -val tokenRegex = Regex("/gs/SS.*?/") - /** * Request wrapper for rewriting the URI after token check. */ -class RewriteWrapper(req: HttpServletRequest) : HttpServletRequestWrapper(req) { - val newUri = req.requestURI.replace(tokenRegex, "/g/") - val newUrl = req.requestURL.toString().replace(tokenRegex, "/g/") - val newSp = req.servletPath.replace(tokenRegex, "/g/") - - init { - println("RewriteWrapper: $newUri, $newUrl, $newSp") - } +class RewriteWrapper(req: HttpServletRequest, token: Str) : HttpServletRequestWrapper(req) { + val replace = "/gs/$token/" + val newUri = req.requestURI.replace(replace, "/g/") + val newUrl = req.requestURL.toString().replace(replace, "/g/") + val newSp = req.servletPath.replace(replace, "/g/") override fun getRequestURI() = newUri override fun getRequestURL() = StringBuffer(newUrl) @@ -91,6 +93,6 @@ class AllNetSecureInit( override fun addInterceptors(reg: InterceptorRegistry) { log.info("AllNet: Added token interceptor to secure requests.") - reg.addInterceptor(tokenChecker).addPathPatterns("/gs/**") + reg.addInterceptor(tokenChecker).addPathPatterns("/gs/**", "/g/**") } }