pass allocator to zig programs that opt in/pass request to lib
This commit is contained in:
		
							parent
							
								
									271c53f650
								
							
						
					
					
						commit
						cfd190d29f
					
				
					 3 changed files with 131 additions and 45 deletions
				
			
		|  | @ -16,7 +16,7 @@ pub const Response = extern struct { | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| pub const Request = extern struct { | pub const Request = extern struct { | ||||||
|     method: [*]u8, |     method: [*:0]u8, | ||||||
|     method_len: usize, |     method_len: usize, | ||||||
| 
 | 
 | ||||||
|     content: [*]u8, |     content: [*]u8, | ||||||
|  |  | ||||||
|  | @ -3,23 +3,28 @@ const interface = @import("interface.zig"); | ||||||
| const testing = std.testing; | const testing = std.testing; | ||||||
| 
 | 
 | ||||||
| const log = std.log.scoped(.@"main-lib"); | const log = std.log.scoped(.@"main-lib"); | ||||||
| var child_allocator = std.heap.raw_c_allocator; // raw allocator recommended for use in arenas |  | ||||||
| var arena: std.heap.ArenaAllocator = undefined; |  | ||||||
| 
 | 
 | ||||||
|  | var allocator: ?*std.mem.Allocator = null; | ||||||
| const Response = struct { | const Response = struct { | ||||||
|     body: *std.ArrayList(u8), |     body: *std.ArrayList(u8), | ||||||
|     headers: *std.StringHashMap([]const u8), |     headers: *std.StringHashMap([]const u8), | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | /// This function is optional and can be exported by zig libraries for | ||||||
|  | /// initialization. If exported, it will be called once in the beginning of | ||||||
|  | /// a request and will be provided a pointer to std.mem.Allocator, which is | ||||||
|  | /// useful for reusing the parent allocator | ||||||
|  | export fn zigInit(parent_allocator: *anyopaque) void { | ||||||
|  |     allocator = @ptrCast(*std.mem.Allocator, @alignCast(@alignOf(*std.mem.Allocator), parent_allocator)); | ||||||
|  | } | ||||||
| export fn handle_request() ?*interface.Response { | export fn handle_request() ?*interface.Response { | ||||||
|     arena = std.heap.ArenaAllocator.init(child_allocator); |     var alloc = if (allocator) |a| a.* else @panic("zigInit not called prior to handle_request. This is a coding error"); | ||||||
|     var allocator = arena.allocator(); |  | ||||||
| 
 | 
 | ||||||
|     // setup response body |     // setup response body | ||||||
|     var response = std.ArrayList(u8).init(allocator); |     var response = std.ArrayList(u8).init(alloc); | ||||||
| 
 | 
 | ||||||
|     // setup headers |     // setup headers | ||||||
|     var headers = std.StringHashMap([]const u8).init(allocator); |     var headers = std.StringHashMap([]const u8).init(alloc); | ||||||
|     handleRequest(.{ |     handleRequest(.{ | ||||||
|         .body = &response, |         .body = &response, | ||||||
|         .headers = &headers, |         .headers = &headers, | ||||||
|  | @ -34,10 +39,10 @@ export fn handle_request() ?*interface.Response { | ||||||
|     log.debug("response ptr: {*}", .{response.items.ptr}); |     log.debug("response ptr: {*}", .{response.items.ptr}); | ||||||
|     // Marshall data back for handling by server |     // Marshall data back for handling by server | ||||||
| 
 | 
 | ||||||
|     var rc = allocator.create(interface.Response) catch @panic("OOM"); |     var rc = alloc.create(interface.Response) catch @panic("OOM"); | ||||||
|     rc.ptr = response.items.ptr; |     rc.ptr = response.items.ptr; | ||||||
|     rc.len = response.items.len; |     rc.len = response.items.len; | ||||||
|     rc.headers = interface.toHeaders(allocator, headers) catch |e| { |     rc.headers = interface.toHeaders(alloc, headers) catch |e| { | ||||||
|         log.err("Unexpected error processing request: {any}", .{e}); |         log.err("Unexpected error processing request: {any}", .{e}); | ||||||
|         if (@errorReturnTrace()) |trace| { |         if (@errorReturnTrace()) |trace| { | ||||||
|             std.debug.dumpStackTrace(trace.*); |             std.debug.dumpStackTrace(trace.*); | ||||||
|  | @ -48,10 +53,10 @@ export fn handle_request() ?*interface.Response { | ||||||
|     return rc; |     return rc; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// having request_deinit allows for a general deinit as well | /// request_deinit is an optional export and will be called a the end of the | ||||||
| export fn request_deinit() void { | /// request. Useful for deallocating memory | ||||||
|     arena.deinit(); | // export fn request_deinit() void { | ||||||
| } | // } | ||||||
| 
 | 
 | ||||||
| // ************************************************************************ | // ************************************************************************ | ||||||
| // Boilerplate ^^, Custom code below | // Boilerplate ^^, Custom code below | ||||||
|  | @ -69,8 +74,10 @@ fn handleRequest(response: Response) !void { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| test "handle_request" { | test "handle_request" { | ||||||
|     defer request_deinit(); |     var arena = std.heap.ArenaAllocator.init(std.testing.allocator); | ||||||
|     child_allocator = std.testing.allocator; |     defer arena.deinit(); | ||||||
|  |     var aa = arena.allocator(); | ||||||
|  |     allocator = &aa; | ||||||
|     const response = handle_request().?; |     const response = handle_request().?; | ||||||
|     try testing.expectEqualStrings(" 2.", response.ptr[0..response.len]); |     try testing.expectEqualStrings(" 2.", response.ptr[0..response.len]); | ||||||
|     try testing.expectEqualStrings("X-custom-foo", response.headers[0].name_ptr[0..response.headers[0].name_len]); |     try testing.expectEqualStrings("X-custom-foo", response.headers[0].name_ptr[0..response.headers[0].name_len]); | ||||||
|  |  | ||||||
							
								
								
									
										137
									
								
								src/main.zig
									
										
									
									
									
								
							
							
						
						
									
										137
									
								
								src/main.zig
									
										
									
									
									
								
							|  | @ -2,7 +2,8 @@ const std = @import("std"); | ||||||
| const builtin = @import("builtin"); | const builtin = @import("builtin"); | ||||||
| const interface = @import("interface.zig"); | const interface = @import("interface.zig"); | ||||||
| const Watch = @import("Watch.zig"); | const Watch = @import("Watch.zig"); | ||||||
| const serveFn = *const fn () ?*interface.Response; | const serveFn = *const fn (*interface.Request) ?*interface.Response; | ||||||
|  | const zigInitFn = *const fn (*anyopaque) void; | ||||||
| const requestDeinitFn = *const fn () void; | const requestDeinitFn = *const fn () void; | ||||||
| 
 | 
 | ||||||
| const timeout = 250; | const timeout = 250; | ||||||
|  | @ -19,6 +20,7 @@ const Executor = struct { | ||||||
|     // fields used at runtime to do real work |     // fields used at runtime to do real work | ||||||
|     library: ?*anyopaque = null, |     library: ?*anyopaque = null, | ||||||
|     serveFn: ?serveFn = null, |     serveFn: ?serveFn = null, | ||||||
|  |     zigInitFn: ?zigInitFn = null, | ||||||
|     requestDeinitFn: ?requestDeinitFn = null, |     requestDeinitFn: ?requestDeinitFn = null, | ||||||
| 
 | 
 | ||||||
|     // fields used for internal accounting |     // fields used for internal accounting | ||||||
|  | @ -36,6 +38,7 @@ var executors = [_]Executor{ | ||||||
| 
 | 
 | ||||||
| var watcher = Watch.init(executorChanged); | var watcher = Watch.init(executorChanged); | ||||||
| var watcher_thread: ?std.Thread = null; | var watcher_thread: ?std.Thread = null; | ||||||
|  | var timer: ?std.time.Timer = null; // timer used by processRequest | ||||||
| 
 | 
 | ||||||
| const log = std.log.scoped(.main); | const log = std.log.scoped(.main); | ||||||
| pub const std_options = struct { | pub const std_options = struct { | ||||||
|  | @ -48,26 +51,33 @@ pub const std_options = struct { | ||||||
| const SERVE_FN_NAME = "handle_request"; | const SERVE_FN_NAME = "handle_request"; | ||||||
| const PORT = 8069; | const PORT = 8069; | ||||||
| 
 | 
 | ||||||
| fn serve(allocator: std.mem.Allocator, response: *std.http.Server.Response) !*FullReturn { | fn serve(allocator: *std.mem.Allocator, response: *std.http.Server.Response) !*FullReturn { | ||||||
|     // pub const Request = extern struct { |  | ||||||
|     //     method: [*]u8, |  | ||||||
|     //     method_len: usize, |  | ||||||
|     // |  | ||||||
|     //     content: [*]u8, |  | ||||||
|     //     content_len: usize, |  | ||||||
|     // |  | ||||||
|     //     headers: [*]Header, |  | ||||||
|     //     headers_len: usize, |  | ||||||
|     // }; |  | ||||||
|     // if (some path routing thing) { |     // if (some path routing thing) { | ||||||
|     // TODO: Get request body into executor |  | ||||||
|     // TODO: Get headers back from executor |  | ||||||
|     // TODO: Get request headers into executor |  | ||||||
|     const executor = try getExecutor(0); |     const executor = try getExecutor(0); | ||||||
|  |     if (executor.zigInitFn) |f| | ||||||
|  |         f(allocator); | ||||||
|  | 
 | ||||||
|     executor.in_request_lock = true; |     executor.in_request_lock = true; | ||||||
|     errdefer executor.in_request_lock = false; |     errdefer executor.in_request_lock = false; | ||||||
|     // Call external library |     // Call external library | ||||||
|     var serve_result = executor.serveFn.?().?; // ok for this pointer deref to fail |     const method_tag = @tagName(response.request.method); | ||||||
|  |     const headers = try toHeaders(allocator.*, response.request.headers); | ||||||
|  |     var request_content: []u8 = &[_]u8{}; | ||||||
|  |     if (response.request.content_length) |l| { | ||||||
|  |         request_content = try response.reader().readAllAlloc(allocator.*, @as(usize, l)); | ||||||
|  |     } | ||||||
|  |     log.debug("{d} bytes read from request", .{request_content.len}); | ||||||
|  |     var request = interface.Request{ | ||||||
|  |         .method = @constCast(method_tag[0..].ptr), | ||||||
|  |         .method_len = method_tag.len, | ||||||
|  | 
 | ||||||
|  |         .headers = headers, | ||||||
|  |         .headers_len = response.request.headers.list.items.len, | ||||||
|  | 
 | ||||||
|  |         .content = request_content.ptr, | ||||||
|  |         .content_len = request_content.len, | ||||||
|  |     }; | ||||||
|  |     var serve_result = executor.serveFn.?(&request).?; // ok for this pointer deref to fail | ||||||
|     log.debug("target: {s}", .{response.request.target}); |     log.debug("target: {s}", .{response.request.target}); | ||||||
|     log.warn("response ptr: {*}", .{serve_result.ptr}); // BUG: This works in tests, but does not when compiled (even debug mode) |     log.warn("response ptr: {*}", .{serve_result.ptr}); // BUG: This works in tests, but does not when compiled (even debug mode) | ||||||
|     var slice: []u8 = serve_result.ptr[0..serve_result.len]; |     var slice: []u8 = serve_result.ptr[0..serve_result.len]; | ||||||
|  | @ -77,14 +87,13 @@ fn serve(allocator: std.mem.Allocator, response: *std.http.Server.Response) !*Fu | ||||||
|     var content_type_added = false; |     var content_type_added = false; | ||||||
|     for (0..serve_result.headers_len) |inx| { |     for (0..serve_result.headers_len) |inx| { | ||||||
|         const head = serve_result.headers[inx]; |         const head = serve_result.headers[inx]; | ||||||
|         // head.name_ptr[0..head.name_len], |  | ||||||
|         try response.headers.append( |         try response.headers.append( | ||||||
|             head.name_ptr[0..head.name_len], |             head.name_ptr[0..head.name_len], | ||||||
|             head.value_ptr[0..head.value_len], |             head.value_ptr[0..head.value_len], | ||||||
|         ); |         ); | ||||||
| 
 | 
 | ||||||
|         // TODO: are these headers case insensitive? |         // headers are case insensitive | ||||||
|         content_type_added = std.mem.eql(u8, head.name_ptr[0..head.name_len], "content-type"); |         content_type_added = std.ascii.eqlIgnoreCase(head.name_ptr[0..head.name_len], "content-type"); | ||||||
|     } |     } | ||||||
|     if (!content_type_added) |     if (!content_type_added) | ||||||
|         try response.headers.append("content-type", "text/plain"); |         try response.headers.append("content-type", "text/plain"); | ||||||
|  | @ -129,6 +138,9 @@ fn loadOptionalSymbols(executor: *Executor) void { | ||||||
|     if (std.c.dlsym(executor.library.?, "request_deinit")) |s| { |     if (std.c.dlsym(executor.library.?, "request_deinit")) |s| { | ||||||
|         executor.requestDeinitFn = @ptrCast(requestDeinitFn, s); |         executor.requestDeinitFn = @ptrCast(requestDeinitFn, s); | ||||||
|     } |     } | ||||||
|  |     if (std.c.dlsym(executor.library.?, "zigInit")) |s| { | ||||||
|  |         executor.zigInitFn = @ptrCast(zigInitFn, s); | ||||||
|  |     } | ||||||
| } | } | ||||||
| fn executorChanged(watch: usize) void { | fn executorChanged(watch: usize) void { | ||||||
|     // NOTE: This will be called off the main thread |     // NOTE: This will be called off the main thread | ||||||
|  | @ -240,31 +252,53 @@ pub fn main() !void { | ||||||
|         log.info("pid: {d}", .{std.os.linux.getpid()}); |         log.info("pid: {d}", .{std.os.linux.getpid()}); | ||||||
| 
 | 
 | ||||||
|     try installSignalHandler(); |     try installSignalHandler(); | ||||||
|     while (true) { |  | ||||||
|     var arena = std.heap.ArenaAllocator.init(allocator); |     var arena = std.heap.ArenaAllocator.init(allocator); | ||||||
|     defer arena.deinit(); |     defer arena.deinit(); | ||||||
|  |     var aa = arena.allocator(); | ||||||
|  |     const bytes_preallocated = try preWarmArena(aa, &arena, 1); | ||||||
|  |     while (true) { | ||||||
|  |         // TODO: Learn what is typical and change this to .retain_with_limit = <value> | ||||||
|  |         defer { | ||||||
|  |             if (!arena.reset(.{ .retain_capacity = {} })) { | ||||||
|  |                 // reallocation failed, arena is degraded | ||||||
|  |                 log.warn("Arena reset failed and is degraded. Resetting arena", .{}); | ||||||
|  |                 arena.deinit(); | ||||||
|  |                 arena = std.heap.ArenaAllocator.init(allocator); | ||||||
|  |                 aa = arena.allocator(); | ||||||
|  |             } | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         processRequest(arena.allocator(), &server) catch |e| { |         processRequest(&aa, &server, stdout) catch |e| { | ||||||
|             log.err("Unexpected error processing request: {any}", .{e}); |             log.err("Unexpected error processing request: {any}", .{e}); | ||||||
|             if (@errorReturnTrace()) |trace| { |             if (@errorReturnTrace()) |trace| { | ||||||
|                 std.debug.dumpStackTrace(trace.*); |                 std.debug.dumpStackTrace(trace.*); | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
|  |         try stdout.print(" (pre-alloc: {}, alloc: {})\n", .{ bytes_preallocated, arena.queryCapacity() }); | ||||||
|  |         try bw.flush(); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| fn processRequest(allocator: std.mem.Allocator, server: *std.http.Server) !void { | fn processRequest(allocator: *std.mem.Allocator, server: *std.http.Server, writer: anytype) !void { | ||||||
|     const max_header_size = 8192; |     const max_header_size = 8192; | ||||||
|  |     if (timer == null) timer = try std.time.Timer.start(); | ||||||
|  |     var tm = timer.?; | ||||||
|     const res = try server.accept(.{ .dynamic = max_header_size }); |     const res = try server.accept(.{ .dynamic = max_header_size }); | ||||||
|     defer res.deinit(); |     defer res.deinit(); | ||||||
|     defer res.reset(); |     defer res.reset(); | ||||||
|     try res.wait(); |     try res.wait(); // wait for client to send a complete request head | ||||||
| 
 |     // I believe it's fair to start our timer after this is done | ||||||
|     // TODO: deal with this |     tm.reset(); | ||||||
|     var buf: [1024]u8 = undefined; |  | ||||||
|     const n = try res.readAll(&buf); |  | ||||||
|     _ = n; |  | ||||||
| 
 | 
 | ||||||
|  |     // This is an nginx log: | ||||||
|  |     // git.lerch.org 50.39.111.175 - - [16/May/2023:02:56:31 +0000] "POST /api/actions/runner.v1.RunnerService/FetchTask HTTP/2.0" 200 0 "-" "connect-go/1.2.0-dev (go1.20.1)" "172.20.0.5:3000" | ||||||
|  |     // TODO: replicate this | ||||||
|  |     try writer.print("{} - - \"{s} {s} {s}\"", .{ | ||||||
|  |         res.address, | ||||||
|  |         @tagName(res.request.method), | ||||||
|  |         res.request.target, | ||||||
|  |         @tagName(res.request.version), | ||||||
|  |     }); | ||||||
|     // TODO: we need to also have a defer statement to deinit whatever happens |     // TODO: we need to also have a defer statement to deinit whatever happens | ||||||
|     // with the executor library. This will also add a race condition where |     // with the executor library. This will also add a race condition where | ||||||
|     // we could have a memory leak if the executor reloads in the middle of a |     // we could have a memory leak if the executor reloads in the middle of a | ||||||
|  | @ -294,12 +328,48 @@ fn processRequest(allocator: std.mem.Allocator, server: *std.http.Server) !void | ||||||
|         response_bytes = f.response; |         response_bytes = f.response; | ||||||
|     res.transfer_encoding = .{ .content_length = response_bytes.len }; |     res.transfer_encoding = .{ .content_length = response_bytes.len }; | ||||||
|     try res.headers.append("connection", "close"); |     try res.headers.append("connection", "close"); | ||||||
|  |     try writer.print(" {d} ttfb {d:.3}ms", .{ @enumToInt(res.status), @intToFloat(f64, tm.read()) / std.time.ns_per_ms }); | ||||||
|     if (builtin.is_test) writeToTestBuffers(response_bytes, res); |     if (builtin.is_test) writeToTestBuffers(response_bytes, res); | ||||||
|     try res.do(); |     try res.do(); | ||||||
|     _ = try res.writer().writeAll(response_bytes); |     _ = try res.writer().writeAll(response_bytes); | ||||||
|     try res.finish(); |     try res.finish(); | ||||||
|  |     try writer.print(" {d} ttlb {d:.3}ms", .{ | ||||||
|  |         response_bytes.len, | ||||||
|  |         @intToFloat(f64, tm.read()) / std.time.ns_per_ms, | ||||||
|  |     }); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | fn toHeaders(allocator: std.mem.Allocator, headers: std.http.Headers) ![*]interface.Header { | ||||||
|  |     var header_array = try std.ArrayList(interface.Header).initCapacity(allocator, headers.list.items.len); | ||||||
|  |     for (headers.list.items) |kv| { | ||||||
|  |         header_array.appendAssumeCapacity(.{ | ||||||
|  |             .name_ptr = @constCast(kv.name).ptr, | ||||||
|  |             .name_len = kv.name.len, | ||||||
|  | 
 | ||||||
|  |             .value_ptr = @constCast(kv.value).ptr, | ||||||
|  |             .value_len = kv.value.len, | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  |     return header_array.items.ptr; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// Allocates at least preallocated_kb kilobytes of ram for usage. Some overhead | ||||||
|  | /// will mean that more | ||||||
|  | fn preWarmArena(aa: std.mem.Allocator, arena: *std.heap.ArenaAllocator, preallocated_kb: usize) !usize { | ||||||
|  |     if (preallocated_kb == 0) return 0; | ||||||
|  |     // capacity 0 at this point | ||||||
|  |     const warm_array = try aa.alloc(u8, 1024 * preallocated_kb); // after this, we are at 1569 (545 extra) | ||||||
|  |     aa.free(warm_array); | ||||||
|  |     log.debug( | ||||||
|  |         "allocator preallocation. Limiting to: {d} bytes", | ||||||
|  |         .{(arena.queryCapacity() + @as(usize, 1023)) / @as(usize, 1024) * 1024}, | ||||||
|  |     ); | ||||||
|  |     if (!arena.reset(.{ .retain_with_limit = (arena.queryCapacity() + @as(usize, 1023)) / @as(usize, 1024) * 1024 })) | ||||||
|  |         log.warn("arena reset failed, arena degraded", .{}); | ||||||
|  |     var bytes_allocated = arena.queryCapacity(); | ||||||
|  |     log.debug("preallocated {d} bytes", .{bytes_allocated}); | ||||||
|  |     return bytes_allocated; | ||||||
|  | } | ||||||
| fn writeToTestBuffers(response: []const u8, res: *std.http.Server.Response) void { | fn writeToTestBuffers(response: []const u8, res: *std.http.Server.Response) void { | ||||||
|     _ = res; |     _ = res; | ||||||
|     log.debug("writing to test buffers", .{}); |     log.debug("writing to test buffers", .{}); | ||||||
|  | @ -325,10 +395,17 @@ fn testRequest(request_bytes: []const u8) !void { | ||||||
|     try server.listen(address); |     try server.listen(address); | ||||||
|     const server_port = server.socket.listen_address.in.getPort(); |     const server_port = server.socket.listen_address.in.getPort(); | ||||||
| 
 | 
 | ||||||
|  |     var al = std.ArrayList(u8).init(allocator); | ||||||
|  |     defer al.deinit(); | ||||||
|  |     var writer = al.writer(); | ||||||
|  |     var aa = arena.allocator(); | ||||||
|  |     var bytes_allocated: usize = 0; | ||||||
|  |     // pre-warm | ||||||
|  |     bytes_allocated = try preWarmArena(aa, &arena, 1); | ||||||
|     const server_thread = try std.Thread.spawn( |     const server_thread = try std.Thread.spawn( | ||||||
|         .{}, |         .{}, | ||||||
|         processRequest, |         processRequest, | ||||||
|         .{ arena.allocator(), &server }, |         .{ &aa, &server, writer }, | ||||||
|     ); |     ); | ||||||
| 
 | 
 | ||||||
|     const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port); |     const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port); | ||||||
|  | @ -336,6 +413,8 @@ fn testRequest(request_bytes: []const u8) !void { | ||||||
|     _ = try stream.writeAll(request_bytes[0..]); |     _ = try stream.writeAll(request_bytes[0..]); | ||||||
| 
 | 
 | ||||||
|     server_thread.join(); |     server_thread.join(); | ||||||
|  |     log.debug("Bytes allocated during request: {d}", .{arena.queryCapacity() - bytes_allocated}); | ||||||
|  |     log.debug("Stdout: {s}", .{al.items}); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| fn testGet(comptime path: []const u8) !void { | fn testGet(comptime path: []const u8) !void { | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue