diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index c0e782a206..001f8841f1 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -7801,14 +7801,45 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp }; } + fn upgradeWebSocketUserRoute(this: *UserRoute, resp: *App.Response, req: *uws.Request, upgrade_ctx: *uws.uws_socket_context_t) void { + const server = this.server; + const index = this.id; + + var should_deinit_context = false; + var prepared = server.prepareJsRequestContext(req, resp, &should_deinit_context, false) orelse return; + prepared.ctx.upgrade_context = upgrade_ctx; // set the upgrade context + const server_request_list = NamespaceType.routeListGetCached(server.jsValueAssertAlive()).?; + var response_value = Bun__ServerRouteList__callRoute(server.globalThis, index, prepared.request_object, server.jsValueAssertAlive(), server_request_list, &prepared.js_request, req); + + if (server.globalThis.tryTakeException()) |exception| { + response_value = exception; + } + + server.handleRequest(&should_deinit_context, prepared, req, response_value); + } + pub fn onWebSocketUpgrade( this: *ThisServer, resp: *App.Response, req: *uws.Request, upgrade_ctx: *uws.uws_socket_context_t, - _: usize, + id: usize, ) void { JSC.markBinding(@src()); + if (id == 1) { + // user route this is actually a UserRoute its safe to cast + upgradeWebSocketUserRoute(@ptrCast(this), resp, req, upgrade_ctx); + return; + } + // only access this as *ThisServer only if id is 0 + bun.assert(id == 0); + if (this.config.onRequest == .zero) { + // require fetch method to be set otherwise we dont know what route to call + // this should be the fallback in case no route is provided to upgrade + resp.writeStatus("403 Forbidden"); + resp.endWithoutBody(true); + return; + } this.pending_requests += 1; req.setYield(false); var ctx = this.request_pool_allocator.tryGet() catch bun.outOfMemory(); @@ -7914,15 +7945,47 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp } user_routes_to_build.deinit(bun.default_allocator); } + var has_any_ws = false; + if (this.config.websocket) |*websocket| { + websocket.globalObject = this.globalThis; + websocket.handler.app = app; + websocket.handler.flags.ssl = ssl_enabled; + } // This may get applied multiple times. for (this.user_routes.items) |*user_route| { switch (user_route.route.method) { .any => { app.any(user_route.route.path, *UserRoute, user_route, onUserRouteRequest); + + if (this.config.websocket) |*websocket| { + // Setup user websocket in the route if needed. + if (!has_any_ws) { + // mark if the route is a catch-all so we dont override it + has_any_ws = strings.eqlComptime(user_route.route.path, "/*"); + } + app.ws( + user_route.route.path, + user_route, + 1, // id 1 means is a user route + ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()), + ); + } }, .specific => |method| { app.method(method, user_route.route.path, *UserRoute, user_route, onUserRouteRequest); + // Setup user websocket in the route if needed. + if (this.config.websocket) |*websocket| { + // Websocket upgrade is a GET request + if (method == HTTP.Method.GET) { + app.ws( + user_route.route.path, + user_route, + 1, // id 1 means is a user route + ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()), + ); + } + } }, } } @@ -7965,17 +8028,16 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp } }; - // Setup user websocket routes. - if (this.config.websocket) |*websocket| { - websocket.globalObject = this.globalThis; - websocket.handler.app = app; - websocket.handler.flags.ssl = ssl_enabled; - app.ws( - "/*", - this, - 0, - ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()), - ); + // Setup user websocket fallback route aka fetch function if fetch is not provided will respond with 403. + if (!has_any_ws) { + if (this.config.websocket) |*websocket| { + app.ws( + "/*", + this, + 0, // id 0 means is a fallback route and ctx is the server + ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()), + ); + } } if (debug_mode) { diff --git a/test/js/bun/http/bun-server.test.ts b/test/js/bun/http/bun-server.test.ts index c1adc9d0d3..ac7778de75 100644 --- a/test/js/bun/http/bun-server.test.ts +++ b/test/js/bun/http/bun-server.test.ts @@ -1116,3 +1116,142 @@ describe("HEAD requests #15355", () => { }); }); }); + +describe("websocket and routes test", () => { + const serverConfigurations = [ + { + // main route for upgrade + routes: { + "/": (req: Request, server: Server) => { + if (server.upgrade(req)) return; + return new Response("Forbidden", { status: 403 }); + }, + }, + shouldBeUpgraded: true, + hasPOST: false, + testName: "main route for upgrade", + }, + { + // Generic route for upgrade + routes: { + "/*": (req: Request, server: Server) => { + if (server.upgrade(req)) return; + return new Response("Forbidden", { status: 403 }); + }, + }, + shouldBeUpgraded: true, + hasPOST: false, + expectedPath: "/bun", + testName: "generic route for upgrade", + }, + // GET route for upgrade + { + routes: { + "/ws": { + GET: (req: Request, server: Server) => { + if (server.upgrade(req)) return; + return new Response("Forbidden", { status: 403 }); + }, + POST: (req: Request) => { + return new Response(req.body); + }, + }, + }, + shouldBeUpgraded: true, + hasPOST: true, + expectedPath: "/ws", + testName: "GET route for upgrade", + }, + // POST route and fetch route for upgrade + { + routes: { + "/": { + POST: (req: Request, server: Server) => { + return new Response("Hello World"); + }, + }, + }, + fetch: (req: Request, server: Server) => { + if (server.upgrade(req)) return; + return new Response("Forbidden", { status: 403 }); + }, + shouldBeUpgraded: true, + hasPOST: true, + testName: "POST route + fetch route for upgrade", + }, + // POST route for upgrade + { + routes: { + "/": { + POST: (req: Request, server: Server) => { + return new Response("Hello World"); + }, + }, + }, + shouldBeUpgraded: false, + hasPOST: true, + testName: "POST route for upgrade and no fetch", + }, + // fetch only + { + fetch: (req: Request, server: Server) => { + if (server.upgrade(req)) return; + return new Response("Forbidden", { status: 403 }); + }, + shouldBeUpgraded: true, + hasPOST: false, + testName: "fetch only for upgrade", + }, + ]; + for (const config of serverConfigurations) { + const { routes, fetch: serverFetch, shouldBeUpgraded, hasPOST, expectedPath, testName } = config; + test(testName, async () => { + using server = Bun.serve({ + port: 0, + routes, + fetch: serverFetch, + websocket: { + message: (ws, message) => { + // PING PONG + ws.send(`recv: ${message}`); + }, + }, + }); + + { + const { promise, resolve, reject } = Promise.withResolvers(); + const url = new URL(server.url); + url.pathname = expectedPath || "/"; + url.hostname = "127.0.0.1"; + const ws = new WebSocket(url.toString()); // bun crashes here + ws.onopen = () => { + ws.send("Hello server"); + }; + ws.onmessage = event => { + resolve(event.data); + ws.close(); + }; + ws.onerror = reject; + ws.onclose = event => { + reject(event.code); + }; + if (shouldBeUpgraded) { + const result = await promise; + expect(result).toBe("recv: Hello server"); + } else { + const result = await promise.catch(e => e); + expect(result).toBe(1002); + } + if (hasPOST) { + const result = await fetch(url, { + method: "POST", + body: "Hello World", + }); + expect(result.status).toBe(200); + const body = await result.text(); + expect(body).toBe("Hello World"); + } + } + }); + } +});