diff --git a/src/aws.zig b/src/aws.zig index a44423b..7f94d29 100644 --- a/src/aws.zig +++ b/src/aws.zig @@ -13,6 +13,9 @@ const xml_serializer = @import("xml_serializer.zig"); const scoped_log = std.log.scoped(.aws); +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + /// control all logs directly/indirectly used by aws sdk. Not recommended for /// use under normal circumstances, but helpful for times when the zig logging /// controls are insufficient (e.g. use in build script) @@ -92,7 +95,7 @@ pub const Options = struct { pub const Diagnostics = struct { http_code: i64, response_body: []const u8, - allocator: std.mem.Allocator, + allocator: Allocator, pub fn deinit(self: *Diagnostics) void { self.allocator.free(self.response_body); @@ -114,12 +117,12 @@ pub const ClientOptions = struct { proxy: ?std.http.Client.Proxy = null, }; pub const Client = struct { - allocator: std.mem.Allocator, + allocator: Allocator, aws_http: awshttp.AwsHttp, const Self = @This(); - pub fn init(allocator: std.mem.Allocator, options: ClientOptions) Self { + pub fn init(allocator: Allocator, options: ClientOptions) Self { return Self{ .allocator = allocator, .aws_http = awshttp.AwsHttp.init(allocator, options.proxy), @@ -229,7 +232,7 @@ pub fn Request(comptime request_action: anytype) type { // We don't know if we need a body...guessing here, this should cover most var buffer = std.ArrayList(u8).init(options.client.allocator); defer buffer.deinit(); - var nameAllocator = std.heap.ArenaAllocator.init(options.client.allocator); + var nameAllocator = ArenaAllocator.init(options.client.allocator); defer nameAllocator.deinit(); if (Self.service_meta.aws_protocol == .rest_json_1) { if (std.mem.eql(u8, "PUT", aws_request.method) or std.mem.eql(u8, "POST", aws_request.method)) { @@ -326,7 +329,7 @@ pub fn Request(comptime request_action: anytype) type { // for a boxed member with no observable difference." But we're // seeing a lot of differences here between spec and reality // - var nameAllocator = std.heap.ArenaAllocator.init(options.client.allocator); + var nameAllocator = ArenaAllocator.init(options.client.allocator); defer nameAllocator.deinit(); try json.stringify(request, .{ .whitespace = .{} }, buffer.writer()); @@ -359,13 +362,16 @@ pub fn Request(comptime request_action: anytype) type { const continuation = if (buffer.items.len > 0) "&" else ""; const query = if (Self.service_meta.aws_protocol == .query) - try std.fmt.allocPrint(options.client.allocator, "", .{}) + "" else // EC2 try std.fmt.allocPrint(options.client.allocator, "?Action={s}&Version={s}", .{ action.action_name, Self.service_meta.version, }); - defer options.client.allocator.free(query); + + defer if (Self.service_meta.aws_protocol != .query) { + options.client.allocator.free(query); + }; // Note: EC2 avoided the Action={s}&Version={s} in the body, but it's // but it's required, so I'm not sure why that code was put in @@ -378,6 +384,7 @@ pub fn Request(comptime request_action: anytype) type { buffer.items, }); defer options.client.allocator.free(body); + return try Self.callAws(.{ .query = query, .body = body, @@ -465,7 +472,7 @@ pub fn Request(comptime request_action: anytype) type { } fn setHeaderValue( - allocator: std.mem.Allocator, + allocator: Allocator, response: anytype, comptime field_name: []const u8, comptime field_type: type, @@ -491,22 +498,23 @@ pub fn Request(comptime request_action: anytype) type { expected_body_field_len -= std.meta.fields(@TypeOf(action.Response.http_header)).len; } + var buf_request_id: [256]u8 = undefined; + const request_id = try requestIdFromHeaders(&buf_request_id, options.client.allocator, aws_request, response); + if (@hasDecl(action.Response, "http_payload")) { - var rc = FullResponseType{ + var rc = try FullResponseType.init(.{ + .arena = ArenaAllocator.init(options.client.allocator), .response = .{}, - .response_metadata = .{ - .request_id = try requestIdFromHeaders(aws_request, response, options), - }, - .parser_options = .{ .json = .{} }, + .request_id = request_id, .raw_parsed = .{ .raw = .{} }, - .allocator = options.client.allocator, - }; + }); + const body_field = @field(rc.response, action.Response.http_payload); const BodyField = @TypeOf(body_field); if (BodyField == []const u8 or BodyField == ?[]const u8) { expected_body_field_len = 0; // We can't use body_field for this set - only @field will work - @field(rc.response, action.Response.http_payload) = try options.client.allocator.dupe(u8, response.body); + // @field(rc.response, action.Response.http_payload) = try rc.arena.allocator().dupe(u8, response.body); return rc; } rc.deinit(); @@ -515,15 +523,12 @@ pub fn Request(comptime request_action: anytype) type { // We don't care about the body if there are no fields we expect there... if (std.meta.fields(action.Response).len == 0 or expected_body_field_len == 0 or response.body.len == 0) { // Do we care if an unexpected body comes in? - return FullResponseType{ + return try FullResponseType.init(.{ + .arena = ArenaAllocator.init(options.client.allocator), .response = undefined, - .response_metadata = .{ - .request_id = try requestIdFromHeaders(aws_request, response, options), - }, - .parser_options = .{ .json = .{} }, + .request_id = request_id, .raw_parsed = .{ .raw = undefined }, - .allocator = options.client.allocator, - }; + }); } return switch (try getContentType(response.headers)) { @@ -570,26 +575,24 @@ pub fn Request(comptime request_action: anytype) type { // We can grab index [0] as structs are guaranteed by zig to be returned in the order // declared, and we're declaring in that order in ServerResponse(). const real_response = @field(parsed_response, @typeInfo(response_types.NormalResponse).@"struct".fields[0].name); - return FullResponseType{ + + return try FullResponseType.init(.{ + .arena = ArenaAllocator.init(options.client.allocator), .response = @field(real_response, @typeInfo(@TypeOf(real_response)).@"struct".fields[0].name), - .response_metadata = .{ - .request_id = try options.client.allocator.dupe(u8, real_response.ResponseMetadata.RequestId), - }, - .parser_options = .{ .json = parser_options }, + .request_id = real_response.ResponseMetadata.RequestId, .raw_parsed = .{ .server = parsed_response }, - .allocator = options.client.allocator, - }; + }); } else { // Conditions 2 or 3 (no wrapping) - return FullResponseType{ + var buf_request_id: [256]u8 = undefined; + const request_id = try requestIdFromHeaders(&buf_request_id, options.client.allocator, aws_request, response); + + return try FullResponseType.init(.{ + .arena = ArenaAllocator.init(options.client.allocator), .response = parsed_response, - .response_metadata = .{ - .request_id = try requestIdFromHeaders(aws_request, response, options), - }, - .parser_options = .{ .json = parser_options }, + .request_id = request_id, .raw_parsed = .{ .raw = parsed_response }, - .allocator = options.client.allocator, - }; + }); } } @@ -662,23 +665,21 @@ pub fn Request(comptime request_action: anytype) type { defer if (free_body) options.client.allocator.free(body); const parsed = try xml_shaper.parse(action.Response, body, xml_options); errdefer parsed.deinit(); - // This needs to get into FullResponseType somehow: defer parsed.deinit(); - const request_id = blk: { - if (parsed.document.root.getCharData("requestId")) |elem| - break :blk try options.client.allocator.dupe(u8, elem); - break :blk try requestIdFromHeaders(request, result, options); - }; - defer options.client.allocator.free(request_id); - return FullResponseType{ - .response = parsed.parsed_value, - .response_metadata = .{ - .request_id = try options.client.allocator.dupe(u8, request_id), - }, - .parser_options = .{ .xml = xml_options }, - .raw_parsed = .{ .xml = parsed }, - .allocator = options.client.allocator, + var buf_request_id: [256]u8 = undefined; + const request_id = blk: { + if (parsed.document.root.getCharData("requestId")) |elem| { + break :blk elem; + } + break :blk try requestIdFromHeaders(&buf_request_id, options.client.allocator, request, result); }; + + return try FullResponseType.init(.{ + .arena = ArenaAllocator.init(options.client.allocator), + .response = parsed.parsed_value, + .request_id = request_id, + .raw_parsed = .{ .xml = parsed }, + }); } const ServerResponseTypes = struct { NormalResponse: type, @@ -741,7 +742,7 @@ pub fn Request(comptime request_action: anytype) type { fn ParsedJsonData(comptime T: type) type { return struct { parsed_response_ptr: *T, - allocator: std.mem.Allocator, + allocator: Allocator, const MySelf = @This(); @@ -754,6 +755,7 @@ pub fn Request(comptime request_action: anytype) type { fn parseJsonData(comptime response_types: ServerResponseTypes, data: []const u8, options: Options, parser_options: json.ParseOptions) !ParsedJsonData(response_types.NormalResponse) { // Now it's time to start looking at the actual data. Job 1 will // be to figure out if this is a raw response or wrapped + const allocator = options.client.allocator; // Extract the first json key const key = firstJsonKey(data); @@ -763,8 +765,8 @@ pub fn Request(comptime request_action: anytype) type { isOtherNormalResponse(response_types.NormalResponse, key); var stream = json.TokenStream.init(data); const parsed_response_ptr = blk: { - const ptr = try options.client.allocator.create(response_types.NormalResponse); - errdefer options.client.allocator.destroy(ptr); + const ptr = try allocator.create(response_types.NormalResponse); + errdefer allocator.destroy(ptr); if (!response_types.isRawPossible or found_normal_json_response) { ptr.* = (json.parse(response_types.NormalResponse, &stream, parser_options) catch |e| { @@ -807,7 +809,7 @@ pub fn Request(comptime request_action: anytype) type { }; return ParsedJsonData(response_types.NormalResponse){ .parsed_response_ptr = parsed_response_ptr, - .allocator = options.client.allocator, + .allocator = allocator, }; } }; @@ -861,14 +863,14 @@ fn parseInt(comptime T: type, val: []const u8) !T { return rc; } -fn generalAllocPrint(allocator: std.mem.Allocator, val: anytype) !?[]const u8 { +fn generalAllocPrint(allocator: Allocator, val: anytype) !?[]const u8 { switch (@typeInfo(@TypeOf(val))) { .optional => if (val) |v| return generalAllocPrint(allocator, v) else return null, .array, .pointer => return try std.fmt.allocPrint(allocator, "{s}", .{val}), else => return try std.fmt.allocPrint(allocator, "{any}", .{val}), } } -fn headersFor(allocator: std.mem.Allocator, request: anytype) ![]awshttp.Header { +fn headersFor(allocator: Allocator, request: anytype) ![]awshttp.Header { log.debug("Checking for headers to include for type {}", .{@TypeOf(request)}); if (!@hasDecl(@TypeOf(request), "http_header")) return &[_]awshttp.Header{}; const http_header = @TypeOf(request).http_header; @@ -892,7 +894,7 @@ fn headersFor(allocator: std.mem.Allocator, request: anytype) ![]awshttp.Header return headers.toOwnedSlice(); } -fn freeHeadersFor(allocator: std.mem.Allocator, request: anytype, headers: []const awshttp.Header) void { +fn freeHeadersFor(allocator: Allocator, request: anytype, headers: []const awshttp.Header) void { if (!@hasDecl(@TypeOf(request), "http_header")) return; const http_header = @TypeOf(request).http_header; const fields = std.meta.fields(@TypeOf(http_header)); @@ -951,8 +953,9 @@ fn getContentType(headers: []const awshttp.Header) !ContentType { return error.ContentTypeNotFound; } -/// Get request ID from headers. Caller responsible for freeing memory -fn requestIdFromHeaders(request: awshttp.HttpRequest, response: awshttp.HttpResult, options: Options) ![]u8 { +/// Get request ID from headers. +/// Allocation is only used in case of an error. Caller does not need to free the returned buffer. +fn requestIdFromHeaders(buf: []u8, allocator: Allocator, request: awshttp.HttpRequest, response: awshttp.HttpResult) ![]u8 { var rid: ?[]const u8 = null; // This "thing" is called: // * Host ID @@ -972,11 +975,14 @@ fn requestIdFromHeaders(request: awshttp.HttpRequest, response: awshttp.HttpResu host_id = header.value; } if (rid) |r| { - if (host_id) |h| - return try std.fmt.allocPrint(options.client.allocator, "{s}, host_id: {s}", .{ r, h }); - return try options.client.allocator.dupe(u8, r); + if (host_id) |h| { + return try std.fmt.bufPrint(buf, "{s}, host_id: {s}", .{ r, h }); + } + + @memcpy(buf[0..r.len], r); + return buf[0..r.len]; } - try reportTraffic(options.client.allocator, "Request ID not found", request, response, log.err); + try reportTraffic(allocator, "Request ID not found", request, response, log.err); return error.RequestIdNotFound; } fn ServerResponse(comptime action: anytype) type { @@ -1029,65 +1035,62 @@ fn ServerResponse(comptime action: anytype) type { } fn FullResponse(comptime action: anytype) type { return struct { - response: action.Response, - response_metadata: struct { - request_id: []u8, - }, - parser_options: union(enum) { - json: json.ParseOptions, - xml: xml_shaper.ParseOptions, - }, - raw_parsed: union(enum) { + pub const ResponseMetadata = struct { + request_id: []const u8, + }; + + pub const RawParsed = union(enum) { server: ServerResponse(action), raw: action.Response, xml: xml_shaper.Parsed(action.Response), - }, - allocator: std.mem.Allocator, + }; + + pub const FullResponseOptions = struct { + response: action.Response = undefined, + request_id: []const u8, + raw_parsed: RawParsed = .{ .raw = undefined }, + arena: ArenaAllocator, + }; + + response: action.Response = undefined, + raw_parsed: RawParsed = .{ .raw = undefined }, + response_metadata: ResponseMetadata, + arena: ArenaAllocator, const Self = @This(); - pub fn deinit(self: Self) void { - switch (self.raw_parsed) { - // Server is json only (so far) - .server => json.parseFree(ServerResponse(action), self.raw_parsed.server, self.parser_options.json), - // Raw is json only (so far) - .raw => json.parseFree(action.Response, self.raw_parsed.raw, self.parser_options.json), - .xml => |xml| xml.deinit(), - } - self.allocator.free(self.response_metadata.request_id); - const Response = @TypeOf(self.response); - if (@hasDecl(Response, "http_header")) { - inline for (std.meta.fields(@TypeOf(Response.http_header))) |f| { - safeFree(self.allocator, @field(self.response, f.name)); - } - } - if (@hasDecl(Response, "http_payload")) { - const body_field = @field(self.response, Response.http_payload); - const BodyField = @TypeOf(body_field); - if (BodyField == []const u8) { - self.allocator.free(body_field); - } - if (BodyField == ?[]const u8) { - if (body_field) |f| - self.allocator.free(f); - } - } + pub fn init(options: FullResponseOptions) !Self { + var arena = options.arena; + const request_id = try arena.allocator().dupe(u8, options.request_id); + + return Self{ + .arena = arena, + .response = options.response, + .raw_parsed = options.raw_parsed, + .response_metadata = .{ + .request_id = request_id, + }, + }; + } + + pub fn deinit(self: Self) void { + self.arena.deinit(); } }; } -fn safeFree(allocator: std.mem.Allocator, obj: anytype) void { +fn safeFree(allocator: Allocator, obj: anytype) void { switch (@typeInfo(@TypeOf(obj))) { .pointer => allocator.free(obj), .optional => if (obj) |o| safeFree(allocator, o), else => {}, } } -fn queryFieldTransformer(allocator: std.mem.Allocator, field_name: []const u8) anyerror![]const u8 { +fn queryFieldTransformer(allocator: Allocator, field_name: []const u8) anyerror![]const u8 { return try case.snakeToPascal(allocator, field_name); } fn buildPath( - allocator: std.mem.Allocator, + allocator: Allocator, raw_uri: []const u8, comptime ActionRequest: type, request: anytype, @@ -1174,7 +1177,7 @@ fn uriEncodeByte(char: u8, writer: anytype, encode_slash: bool) !void { } } -fn buildQuery(allocator: std.mem.Allocator, request: anytype) ![]const u8 { +fn buildQuery(allocator: Allocator, request: anytype) ![]const u8 { // query should look something like this: // pub const http_query = .{ // .master_region = "MasterRegion", @@ -1296,7 +1299,7 @@ pub fn IgnoringWriter(comptime WriterType: type) type { } fn reportTraffic( - allocator: std.mem.Allocator, + allocator: Allocator, info: []const u8, request: awshttp.HttpRequest, response: awshttp.HttpResult, @@ -1498,7 +1501,7 @@ test "basic json request serialization" { // for a boxed member with no observable difference." But we're // seeing a lot of differences here between spec and reality // - var nameAllocator = std.heap.ArenaAllocator.init(allocator); + var nameAllocator = ArenaAllocator.init(allocator); defer nameAllocator.deinit(); try json.stringify(request, .{ .whitespace = .{} }, buffer.writer()); try std.testing.expectEqualStrings( @@ -1582,8 +1585,8 @@ test { std.testing.refAllDecls(xml_shaper); } const TestOptions = struct { - allocator: std.mem.Allocator, - arena: ?*std.heap.ArenaAllocator = null, + allocator: Allocator, + arena: ?*ArenaAllocator = null, server_port: ?u16 = null, server_remaining_requests: usize = 1, server_response: []const u8 = "unset", @@ -1672,8 +1675,8 @@ const TestOptions = struct { fn threadMain(options: *TestOptions) !void { // https://github.com/ziglang/zig/blob/d2be725e4b14c33dbd39054e33d926913eee3cd4/lib/compiler/std-docs.zig#L22-L54 - options.arena = try options.allocator.create(std.heap.ArenaAllocator); - options.arena.?.* = std.heap.ArenaAllocator.init(options.allocator); + options.arena = try options.allocator.create(ArenaAllocator); + options.arena.?.* = ArenaAllocator.init(options.allocator); const allocator = options.arena.?.allocator(); options.allocator = allocator; @@ -1684,7 +1687,7 @@ fn threadMain(options: *TestOptions) !void { options.test_server_runtime_uri = try std.fmt.allocPrint(options.allocator, "http://127.0.0.1:{d}", .{options.server_port.?}); log.debug("server listening at {s}", .{options.test_server_runtime_uri.?}); log.info("starting server thread, tid {d}", .{std.Thread.getCurrentId()}); - // var arena = std.heap.ArenaAllocator.init(options.allocator); + // var arena = ArenaAllocator.init(options.allocator); // defer arena.deinit(); // var aa = arena.allocator(); // We're in control of all requests/responses, so this flag will tell us @@ -1764,7 +1767,7 @@ fn serveRequest(options: *TestOptions, request: *std.http.Server.Request) !void //////////////////////////////////////////////////////////////////////// const TestSetup = struct { - allocator: std.mem.Allocator, + allocator: Allocator, request_options: TestOptions, server_thread: std.Thread = undefined, creds: aws_auth.Credentials = undefined,