chore: wip refactor of FullResponse to use arena allocator to simplify memory management

This commit is contained in:
Simon Hartcher 2025-04-23 16:33:20 +10:00
parent bd2aede64e
commit 5a8cceaa0b

View file

@ -13,6 +13,9 @@ const xml_serializer = @import("xml_serializer.zig");
const scoped_log = std.log.scoped(.aws);
const Allocator = std.mem.Allocator;
const ArenaAllocator = std.heap.ArenaAllocator;
/// control all logs directly/indirectly used by aws sdk. Not recommended for
/// use under normal circumstances, but helpful for times when the zig logging
/// controls are insufficient (e.g. use in build script)
@ -92,7 +95,7 @@ pub const Options = struct {
pub const Diagnostics = struct {
http_code: i64,
response_body: []const u8,
allocator: std.mem.Allocator,
allocator: Allocator,
pub fn deinit(self: *Diagnostics) void {
self.allocator.free(self.response_body);
@ -114,12 +117,12 @@ pub const ClientOptions = struct {
proxy: ?std.http.Client.Proxy = null,
};
pub const Client = struct {
allocator: std.mem.Allocator,
allocator: Allocator,
aws_http: awshttp.AwsHttp,
const Self = @This();
pub fn init(allocator: std.mem.Allocator, options: ClientOptions) Self {
pub fn init(allocator: Allocator, options: ClientOptions) Self {
return Self{
.allocator = allocator,
.aws_http = awshttp.AwsHttp.init(allocator, options.proxy),
@ -229,7 +232,7 @@ pub fn Request(comptime request_action: anytype) type {
// We don't know if we need a body...guessing here, this should cover most
var buffer = std.ArrayList(u8).init(options.client.allocator);
defer buffer.deinit();
var nameAllocator = std.heap.ArenaAllocator.init(options.client.allocator);
var nameAllocator = ArenaAllocator.init(options.client.allocator);
defer nameAllocator.deinit();
if (Self.service_meta.aws_protocol == .rest_json_1) {
if (std.mem.eql(u8, "PUT", aws_request.method) or std.mem.eql(u8, "POST", aws_request.method)) {
@ -326,7 +329,7 @@ pub fn Request(comptime request_action: anytype) type {
// for a boxed member with no observable difference." But we're
// seeing a lot of differences here between spec and reality
//
var nameAllocator = std.heap.ArenaAllocator.init(options.client.allocator);
var nameAllocator = ArenaAllocator.init(options.client.allocator);
defer nameAllocator.deinit();
try json.stringify(request, .{ .whitespace = .{} }, buffer.writer());
@ -359,13 +362,16 @@ pub fn Request(comptime request_action: anytype) type {
const continuation = if (buffer.items.len > 0) "&" else "";
const query = if (Self.service_meta.aws_protocol == .query)
try std.fmt.allocPrint(options.client.allocator, "", .{})
""
else // EC2
try std.fmt.allocPrint(options.client.allocator, "?Action={s}&Version={s}", .{
action.action_name,
Self.service_meta.version,
});
defer options.client.allocator.free(query);
defer if (Self.service_meta.aws_protocol != .query) {
options.client.allocator.free(query);
};
// Note: EC2 avoided the Action={s}&Version={s} in the body, but it's
// required, so I'm not sure why that code was put in
@ -378,6 +384,7 @@ pub fn Request(comptime request_action: anytype) type {
buffer.items,
});
defer options.client.allocator.free(body);
return try Self.callAws(.{
.query = query,
.body = body,
@ -465,7 +472,7 @@ pub fn Request(comptime request_action: anytype) type {
}
fn setHeaderValue(
allocator: std.mem.Allocator,
allocator: Allocator,
response: anytype,
comptime field_name: []const u8,
comptime field_type: type,
@ -491,22 +498,23 @@ pub fn Request(comptime request_action: anytype) type {
expected_body_field_len -= std.meta.fields(@TypeOf(action.Response.http_header)).len;
}
var buf_request_id: [256]u8 = undefined;
const request_id = try requestIdFromHeaders(&buf_request_id, options.client.allocator, aws_request, response);
if (@hasDecl(action.Response, "http_payload")) {
var rc = FullResponseType{
var rc = try FullResponseType.init(.{
.arena = ArenaAllocator.init(options.client.allocator),
.response = .{},
.response_metadata = .{
.request_id = try requestIdFromHeaders(aws_request, response, options),
},
.parser_options = .{ .json = .{} },
.request_id = request_id,
.raw_parsed = .{ .raw = .{} },
.allocator = options.client.allocator,
};
});
const body_field = @field(rc.response, action.Response.http_payload);
const BodyField = @TypeOf(body_field);
if (BodyField == []const u8 or BodyField == ?[]const u8) {
expected_body_field_len = 0;
// We can't use body_field for this set - only @field will work
@field(rc.response, action.Response.http_payload) = try options.client.allocator.dupe(u8, response.body);
// @field(rc.response, action.Response.http_payload) = try rc.arena.allocator().dupe(u8, response.body);
return rc;
}
rc.deinit();
@ -515,15 +523,12 @@ pub fn Request(comptime request_action: anytype) type {
// We don't care about the body if there are no fields we expect there...
if (std.meta.fields(action.Response).len == 0 or expected_body_field_len == 0 or response.body.len == 0) {
// Do we care if an unexpected body comes in?
return FullResponseType{
return try FullResponseType.init(.{
.arena = ArenaAllocator.init(options.client.allocator),
.response = undefined,
.response_metadata = .{
.request_id = try requestIdFromHeaders(aws_request, response, options),
},
.parser_options = .{ .json = .{} },
.request_id = request_id,
.raw_parsed = .{ .raw = undefined },
.allocator = options.client.allocator,
};
});
}
return switch (try getContentType(response.headers)) {
@ -570,26 +575,24 @@ pub fn Request(comptime request_action: anytype) type {
// We can grab index [0] as structs are guaranteed by zig to be returned in the order
// declared, and we're declaring in that order in ServerResponse().
const real_response = @field(parsed_response, @typeInfo(response_types.NormalResponse).@"struct".fields[0].name);
return FullResponseType{
return try FullResponseType.init(.{
.arena = ArenaAllocator.init(options.client.allocator),
.response = @field(real_response, @typeInfo(@TypeOf(real_response)).@"struct".fields[0].name),
.response_metadata = .{
.request_id = try options.client.allocator.dupe(u8, real_response.ResponseMetadata.RequestId),
},
.parser_options = .{ .json = parser_options },
.request_id = real_response.ResponseMetadata.RequestId,
.raw_parsed = .{ .server = parsed_response },
.allocator = options.client.allocator,
};
});
} else {
// Conditions 2 or 3 (no wrapping)
return FullResponseType{
var buf_request_id: [256]u8 = undefined;
const request_id = try requestIdFromHeaders(&buf_request_id, options.client.allocator, aws_request, response);
return try FullResponseType.init(.{
.arena = ArenaAllocator.init(options.client.allocator),
.response = parsed_response,
.response_metadata = .{
.request_id = try requestIdFromHeaders(aws_request, response, options),
},
.parser_options = .{ .json = parser_options },
.request_id = request_id,
.raw_parsed = .{ .raw = parsed_response },
.allocator = options.client.allocator,
};
});
}
}
@ -662,23 +665,21 @@ pub fn Request(comptime request_action: anytype) type {
defer if (free_body) options.client.allocator.free(body);
const parsed = try xml_shaper.parse(action.Response, body, xml_options);
errdefer parsed.deinit();
// This needs to get into FullResponseType somehow: defer parsed.deinit();
const request_id = blk: {
if (parsed.document.root.getCharData("requestId")) |elem|
break :blk try options.client.allocator.dupe(u8, elem);
break :blk try requestIdFromHeaders(request, result, options);
};
defer options.client.allocator.free(request_id);
return FullResponseType{
.response = parsed.parsed_value,
.response_metadata = .{
.request_id = try options.client.allocator.dupe(u8, request_id),
},
.parser_options = .{ .xml = xml_options },
.raw_parsed = .{ .xml = parsed },
.allocator = options.client.allocator,
var buf_request_id: [256]u8 = undefined;
const request_id = blk: {
if (parsed.document.root.getCharData("requestId")) |elem| {
break :blk elem;
}
break :blk try requestIdFromHeaders(&buf_request_id, options.client.allocator, request, result);
};
return try FullResponseType.init(.{
.arena = ArenaAllocator.init(options.client.allocator),
.response = parsed.parsed_value,
.request_id = request_id,
.raw_parsed = .{ .xml = parsed },
});
}
const ServerResponseTypes = struct {
NormalResponse: type,
@ -741,7 +742,7 @@ pub fn Request(comptime request_action: anytype) type {
fn ParsedJsonData(comptime T: type) type {
return struct {
parsed_response_ptr: *T,
allocator: std.mem.Allocator,
allocator: Allocator,
const MySelf = @This();
@ -754,6 +755,7 @@ pub fn Request(comptime request_action: anytype) type {
fn parseJsonData(comptime response_types: ServerResponseTypes, data: []const u8, options: Options, parser_options: json.ParseOptions) !ParsedJsonData(response_types.NormalResponse) {
// Now it's time to start looking at the actual data. Job 1 will
// be to figure out if this is a raw response or wrapped
const allocator = options.client.allocator;
// Extract the first json key
const key = firstJsonKey(data);
@ -763,8 +765,8 @@ pub fn Request(comptime request_action: anytype) type {
isOtherNormalResponse(response_types.NormalResponse, key);
var stream = json.TokenStream.init(data);
const parsed_response_ptr = blk: {
const ptr = try options.client.allocator.create(response_types.NormalResponse);
errdefer options.client.allocator.destroy(ptr);
const ptr = try allocator.create(response_types.NormalResponse);
errdefer allocator.destroy(ptr);
if (!response_types.isRawPossible or found_normal_json_response) {
ptr.* = (json.parse(response_types.NormalResponse, &stream, parser_options) catch |e| {
@ -807,7 +809,7 @@ pub fn Request(comptime request_action: anytype) type {
};
return ParsedJsonData(response_types.NormalResponse){
.parsed_response_ptr = parsed_response_ptr,
.allocator = options.client.allocator,
.allocator = allocator,
};
}
};
@ -861,14 +863,14 @@ fn parseInt(comptime T: type, val: []const u8) !T {
return rc;
}
/// Formats `val` into a string allocated with `allocator`.
///
/// Dispatch is on the type of `val`:
/// * optional: a null value yields `null`; otherwise the payload is
///   formatted by a recursive call.
/// * array or pointer: printed with the `{s}` specifier.
/// * anything else: printed with the `{any}` specifier.
///
/// Errors from `std.fmt.allocPrint` are propagated to the caller. The
/// returned slice (when non-null) was allocated with `allocator`.
fn generalAllocPrint(allocator: std.mem.Allocator, val: anytype) !?[]const u8 {
fn generalAllocPrint(allocator: Allocator, val: anytype) !?[]const u8 {
switch (@typeInfo(@TypeOf(val))) {
.optional => if (val) |v| return generalAllocPrint(allocator, v) else return null,
.array, .pointer => return try std.fmt.allocPrint(allocator, "{s}", .{val}),
else => return try std.fmt.allocPrint(allocator, "{any}", .{val}),
}
}
fn headersFor(allocator: std.mem.Allocator, request: anytype) ![]awshttp.Header {
fn headersFor(allocator: Allocator, request: anytype) ![]awshttp.Header {
log.debug("Checking for headers to include for type {}", .{@TypeOf(request)});
if (!@hasDecl(@TypeOf(request), "http_header")) return &[_]awshttp.Header{};
const http_header = @TypeOf(request).http_header;
@ -892,7 +894,7 @@ fn headersFor(allocator: std.mem.Allocator, request: anytype) ![]awshttp.Header
return headers.toOwnedSlice();
}
fn freeHeadersFor(allocator: std.mem.Allocator, request: anytype, headers: []const awshttp.Header) void {
fn freeHeadersFor(allocator: Allocator, request: anytype, headers: []const awshttp.Header) void {
if (!@hasDecl(@TypeOf(request), "http_header")) return;
const http_header = @TypeOf(request).http_header;
const fields = std.meta.fields(@TypeOf(http_header));
@ -951,8 +953,9 @@ fn getContentType(headers: []const awshttp.Header) !ContentType {
return error.ContentTypeNotFound;
}
/// Get request ID from headers. Caller responsible for freeing memory
fn requestIdFromHeaders(request: awshttp.HttpRequest, response: awshttp.HttpResult, options: Options) ![]u8 {
/// Get request ID from headers.
/// Allocation is only used in case of an error. Caller does not need to free the returned buffer.
fn requestIdFromHeaders(buf: []u8, allocator: Allocator, request: awshttp.HttpRequest, response: awshttp.HttpResult) ![]u8 {
var rid: ?[]const u8 = null;
// This "thing" is called:
// * Host ID
@ -972,11 +975,14 @@ fn requestIdFromHeaders(request: awshttp.HttpRequest, response: awshttp.HttpResu
host_id = header.value;
}
if (rid) |r| {
if (host_id) |h|
return try std.fmt.allocPrint(options.client.allocator, "{s}, host_id: {s}", .{ r, h });
return try options.client.allocator.dupe(u8, r);
if (host_id) |h| {
return try std.fmt.bufPrint(buf, "{s}, host_id: {s}", .{ r, h });
}
try reportTraffic(options.client.allocator, "Request ID not found", request, response, log.err);
@memcpy(buf[0..r.len], r);
return buf[0..r.len];
}
try reportTraffic(allocator, "Request ID not found", request, response, log.err);
return error.RequestIdNotFound;
}
fn ServerResponse(comptime action: anytype) type {
@ -1029,65 +1035,62 @@ fn ServerResponse(comptime action: anytype) type {
}
fn FullResponse(comptime action: anytype) type {
return struct {
response: action.Response,
response_metadata: struct {
request_id: []u8,
},
parser_options: union(enum) {
json: json.ParseOptions,
xml: xml_shaper.ParseOptions,
},
raw_parsed: union(enum) {
pub const ResponseMetadata = struct {
request_id: []const u8,
};
pub const RawParsed = union(enum) {
server: ServerResponse(action),
raw: action.Response,
xml: xml_shaper.Parsed(action.Response),
},
allocator: std.mem.Allocator,
};
pub const FullResponseOptions = struct {
response: action.Response = undefined,
request_id: []const u8,
raw_parsed: RawParsed = .{ .raw = undefined },
arena: ArenaAllocator,
};
response: action.Response = undefined,
raw_parsed: RawParsed = .{ .raw = undefined },
response_metadata: ResponseMetadata,
arena: ArenaAllocator,
const Self = @This();
pub fn deinit(self: Self) void {
switch (self.raw_parsed) {
// Server is json only (so far)
.server => json.parseFree(ServerResponse(action), self.raw_parsed.server, self.parser_options.json),
// Raw is json only (so far)
.raw => json.parseFree(action.Response, self.raw_parsed.raw, self.parser_options.json),
.xml => |xml| xml.deinit(),
pub fn init(options: FullResponseOptions) !Self {
var arena = options.arena;
const request_id = try arena.allocator().dupe(u8, options.request_id);
return Self{
.arena = arena,
.response = options.response,
.raw_parsed = options.raw_parsed,
.response_metadata = .{
.request_id = request_id,
},
};
}
self.allocator.free(self.response_metadata.request_id);
const Response = @TypeOf(self.response);
if (@hasDecl(Response, "http_header")) {
inline for (std.meta.fields(@TypeOf(Response.http_header))) |f| {
safeFree(self.allocator, @field(self.response, f.name));
}
}
if (@hasDecl(Response, "http_payload")) {
const body_field = @field(self.response, Response.http_payload);
const BodyField = @TypeOf(body_field);
if (BodyField == []const u8) {
self.allocator.free(body_field);
}
if (BodyField == ?[]const u8) {
if (body_field) |f|
self.allocator.free(f);
}
}
pub fn deinit(self: Self) void {
self.arena.deinit();
}
};
}
/// Frees `obj` with `allocator` if its type is a pointer.
///
/// Optionals are unwrapped first: a non-null payload is passed back through
/// `safeFree`, a null optional does nothing. Values of any other type are
/// ignored.
fn safeFree(allocator: std.mem.Allocator, obj: anytype) void {
fn safeFree(allocator: Allocator, obj: anytype) void {
switch (@typeInfo(@TypeOf(obj))) {
.pointer => allocator.free(obj),
.optional => if (obj) |o| safeFree(allocator, o),
else => {},
}
}
/// Transforms a field name for use in a query by delegating to
/// `case.snakeToPascal(allocator, field_name)` and returning its result.
///
/// NOTE(review): judging by the helper's name this converts a snake_case
/// name to PascalCase and allocates the result with `allocator` - confirm
/// against the `case` module.
fn queryFieldTransformer(allocator: std.mem.Allocator, field_name: []const u8) anyerror![]const u8 {
fn queryFieldTransformer(allocator: Allocator, field_name: []const u8) anyerror![]const u8 {
return try case.snakeToPascal(allocator, field_name);
}
fn buildPath(
allocator: std.mem.Allocator,
allocator: Allocator,
raw_uri: []const u8,
comptime ActionRequest: type,
request: anytype,
@ -1174,7 +1177,7 @@ fn uriEncodeByte(char: u8, writer: anytype, encode_slash: bool) !void {
}
}
fn buildQuery(allocator: std.mem.Allocator, request: anytype) ![]const u8 {
fn buildQuery(allocator: Allocator, request: anytype) ![]const u8 {
// query should look something like this:
// pub const http_query = .{
// .master_region = "MasterRegion",
@ -1296,7 +1299,7 @@ pub fn IgnoringWriter(comptime WriterType: type) type {
}
fn reportTraffic(
allocator: std.mem.Allocator,
allocator: Allocator,
info: []const u8,
request: awshttp.HttpRequest,
response: awshttp.HttpResult,
@ -1498,7 +1501,7 @@ test "basic json request serialization" {
// for a boxed member with no observable difference." But we're
// seeing a lot of differences here between spec and reality
//
var nameAllocator = std.heap.ArenaAllocator.init(allocator);
var nameAllocator = ArenaAllocator.init(allocator);
defer nameAllocator.deinit();
try json.stringify(request, .{ .whitespace = .{} }, buffer.writer());
try std.testing.expectEqualStrings(
@ -1582,8 +1585,8 @@ test {
std.testing.refAllDecls(xml_shaper);
}
const TestOptions = struct {
allocator: std.mem.Allocator,
arena: ?*std.heap.ArenaAllocator = null,
allocator: Allocator,
arena: ?*ArenaAllocator = null,
server_port: ?u16 = null,
server_remaining_requests: usize = 1,
server_response: []const u8 = "unset",
@ -1672,8 +1675,8 @@ const TestOptions = struct {
fn threadMain(options: *TestOptions) !void {
// https://github.com/ziglang/zig/blob/d2be725e4b14c33dbd39054e33d926913eee3cd4/lib/compiler/std-docs.zig#L22-L54
options.arena = try options.allocator.create(std.heap.ArenaAllocator);
options.arena.?.* = std.heap.ArenaAllocator.init(options.allocator);
options.arena = try options.allocator.create(ArenaAllocator);
options.arena.?.* = ArenaAllocator.init(options.allocator);
const allocator = options.arena.?.allocator();
options.allocator = allocator;
@ -1684,7 +1687,7 @@ fn threadMain(options: *TestOptions) !void {
options.test_server_runtime_uri = try std.fmt.allocPrint(options.allocator, "http://127.0.0.1:{d}", .{options.server_port.?});
log.debug("server listening at {s}", .{options.test_server_runtime_uri.?});
log.info("starting server thread, tid {d}", .{std.Thread.getCurrentId()});
// var arena = std.heap.ArenaAllocator.init(options.allocator);
// var arena = ArenaAllocator.init(options.allocator);
// defer arena.deinit();
// var aa = arena.allocator();
// We're in control of all requests/responses, so this flag will tell us
@ -1764,7 +1767,7 @@ fn serveRequest(options: *TestOptions, request: *std.http.Server.Request) !void
////////////////////////////////////////////////////////////////////////
const TestSetup = struct {
allocator: std.mem.Allocator,
allocator: Allocator,
request_options: TestOptions,
server_thread: std.Thread = undefined,
creds: aws_auth.Credentials = undefined,