fix(serve) fix WS upgrade with routes (#17805)

This commit is contained in:
Ciro Spaciari
2025-02-28 19:25:55 -08:00
committed by GitHub
parent 12a2f412fc
commit 01fb872095
2 changed files with 213 additions and 12 deletions

View File

@@ -7801,14 +7801,45 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
};
}
fn upgradeWebSocketUserRoute(this: *UserRoute, resp: *App.Response, req: *uws.Request, upgrade_ctx: *uws.uws_socket_context_t) void {
const server = this.server;
const index = this.id;
var should_deinit_context = false;
var prepared = server.prepareJsRequestContext(req, resp, &should_deinit_context, false) orelse return;
prepared.ctx.upgrade_context = upgrade_ctx; // set the upgrade context
const server_request_list = NamespaceType.routeListGetCached(server.jsValueAssertAlive()).?;
var response_value = Bun__ServerRouteList__callRoute(server.globalThis, index, prepared.request_object, server.jsValueAssertAlive(), server_request_list, &prepared.js_request, req);
if (server.globalThis.tryTakeException()) |exception| {
response_value = exception;
}
server.handleRequest(&should_deinit_context, prepared, req, response_value);
}
pub fn onWebSocketUpgrade(
this: *ThisServer,
resp: *App.Response,
req: *uws.Request,
upgrade_ctx: *uws.uws_socket_context_t,
_: usize,
id: usize,
) void {
JSC.markBinding(@src());
if (id == 1) {
// user route this is actually a UserRoute its safe to cast
upgradeWebSocketUserRoute(@ptrCast(this), resp, req, upgrade_ctx);
return;
}
// only access this as *ThisServer only if id is 0
bun.assert(id == 0);
if (this.config.onRequest == .zero) {
// require fetch method to be set otherwise we dont know what route to call
// this should be the fallback in case no route is provided to upgrade
resp.writeStatus("403 Forbidden");
resp.endWithoutBody(true);
return;
}
this.pending_requests += 1;
req.setYield(false);
var ctx = this.request_pool_allocator.tryGet() catch bun.outOfMemory();
@@ -7914,15 +7945,47 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
}
user_routes_to_build.deinit(bun.default_allocator);
}
var has_any_ws = false;
if (this.config.websocket) |*websocket| {
websocket.globalObject = this.globalThis;
websocket.handler.app = app;
websocket.handler.flags.ssl = ssl_enabled;
}
// This may get applied multiple times.
for (this.user_routes.items) |*user_route| {
switch (user_route.route.method) {
.any => {
app.any(user_route.route.path, *UserRoute, user_route, onUserRouteRequest);
if (this.config.websocket) |*websocket| {
// Setup user websocket in the route if needed.
if (!has_any_ws) {
// mark if the route is a catch-all so we dont override it
has_any_ws = strings.eqlComptime(user_route.route.path, "/*");
}
app.ws(
user_route.route.path,
user_route,
1, // id 1 means is a user route
ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()),
);
}
},
.specific => |method| {
app.method(method, user_route.route.path, *UserRoute, user_route, onUserRouteRequest);
// Setup user websocket in the route if needed.
if (this.config.websocket) |*websocket| {
// Websocket upgrade is a GET request
if (method == HTTP.Method.GET) {
app.ws(
user_route.route.path,
user_route,
1, // id 1 means is a user route
ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()),
);
}
}
},
}
}
@@ -7965,17 +8028,16 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
}
};
// Setup user websocket routes.
if (this.config.websocket) |*websocket| {
websocket.globalObject = this.globalThis;
websocket.handler.app = app;
websocket.handler.flags.ssl = ssl_enabled;
app.ws(
"/*",
this,
0,
ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()),
);
// Setup user websocket fallback route aka fetch function if fetch is not provided will respond with 403.
if (!has_any_ws) {
if (this.config.websocket) |*websocket| {
app.ws(
"/*",
this,
0, // id 0 means is a fallback route and ctx is the server
ServerWebSocket.behavior(ThisServer, ssl_enabled, websocket.toBehavior()),
);
}
}
if (debug_mode) {

View File

@@ -1116,3 +1116,142 @@ describe("HEAD requests #15355", () => {
});
});
});
describe("websocket and routes test", () => {
const serverConfigurations = [
{
// main route for upgrade
routes: {
"/": (req: Request, server: Server) => {
if (server.upgrade(req)) return;
return new Response("Forbidden", { status: 403 });
},
},
shouldBeUpgraded: true,
hasPOST: false,
testName: "main route for upgrade",
},
{
// Generic route for upgrade
routes: {
"/*": (req: Request, server: Server) => {
if (server.upgrade(req)) return;
return new Response("Forbidden", { status: 403 });
},
},
shouldBeUpgraded: true,
hasPOST: false,
expectedPath: "/bun",
testName: "generic route for upgrade",
},
// GET route for upgrade
{
routes: {
"/ws": {
GET: (req: Request, server: Server) => {
if (server.upgrade(req)) return;
return new Response("Forbidden", { status: 403 });
},
POST: (req: Request) => {
return new Response(req.body);
},
},
},
shouldBeUpgraded: true,
hasPOST: true,
expectedPath: "/ws",
testName: "GET route for upgrade",
},
// POST route and fetch route for upgrade
{
routes: {
"/": {
POST: (req: Request, server: Server) => {
return new Response("Hello World");
},
},
},
fetch: (req: Request, server: Server) => {
if (server.upgrade(req)) return;
return new Response("Forbidden", { status: 403 });
},
shouldBeUpgraded: true,
hasPOST: true,
testName: "POST route + fetch route for upgrade",
},
// POST route for upgrade
{
routes: {
"/": {
POST: (req: Request, server: Server) => {
return new Response("Hello World");
},
},
},
shouldBeUpgraded: false,
hasPOST: true,
testName: "POST route for upgrade and no fetch",
},
// fetch only
{
fetch: (req: Request, server: Server) => {
if (server.upgrade(req)) return;
return new Response("Forbidden", { status: 403 });
},
shouldBeUpgraded: true,
hasPOST: false,
testName: "fetch only for upgrade",
},
];
for (const config of serverConfigurations) {
const { routes, fetch: serverFetch, shouldBeUpgraded, hasPOST, expectedPath, testName } = config;
test(testName, async () => {
using server = Bun.serve({
port: 0,
routes,
fetch: serverFetch,
websocket: {
message: (ws, message) => {
// PING PONG
ws.send(`recv: ${message}`);
},
},
});
{
const { promise, resolve, reject } = Promise.withResolvers();
const url = new URL(server.url);
url.pathname = expectedPath || "/";
url.hostname = "127.0.0.1";
const ws = new WebSocket(url.toString()); // bun crashes here
ws.onopen = () => {
ws.send("Hello server");
};
ws.onmessage = event => {
resolve(event.data);
ws.close();
};
ws.onerror = reject;
ws.onclose = event => {
reject(event.code);
};
if (shouldBeUpgraded) {
const result = await promise;
expect(result).toBe("recv: Hello server");
} else {
const result = await promise.catch(e => e);
expect(result).toBe(1002);
}
if (hasPOST) {
const result = await fetch(url, {
method: "POST",
body: "Hello World",
});
expect(result.status).toBe(200);
const body = await result.text();
expect(body).toBe("Hello World");
}
}
});
}
});