diff --git a/src/aws.zig b/src/aws.zig index 7fbc46d..cad5b79 100644 --- a/src/aws.zig +++ b/src/aws.zig @@ -102,10 +102,9 @@ pub fn Request(comptime action: anytype) type { .method = Action.http_config.method, .content_type = "application/json", .path = Action.http_config.uri, + .headers = try headersFor(options.client.allocator, request), }; - if (Self.service_meta.aws_protocol == .rest_xml) { - aws_request.content_type = "application/xml"; - } + defer freeHeadersFor(options.client.allocator, request, aws_request.headers); log.debug("Rest method: '{s}'", .{aws_request.method}); log.debug("Rest success code: '{d}'", .{Action.http_config.success_code}); @@ -113,7 +112,25 @@ pub fn Request(comptime action: anytype) type { aws_request.path = try buildPath(options.client.allocator, Action.http_config.uri, ActionRequest, request); defer options.client.allocator.free(aws_request.path); log.debug("Rest processed uri: '{s}'", .{aws_request.path}); + // TODO: Make sure this doesn't get escaped here for S3 aws_request.query = try buildQuery(options.client.allocator, request); + if (aws_request.query.len == 0) { + if (std.mem.indexOf(u8, aws_request.path, "?")) |inx| { + log.debug("Detected query in path. Adjusting", .{}); + // Sometimes (looking at you, s3), the uri in the model + // has a query string shoved into it. If that's the case, + // we need to parse and straighten this all out + const orig_path = aws_request.path; // save as we'll need to dealloc + const orig_query = aws_request.query; // save as we'll need to dealloc + // We need to chop the query off because apparently the other one whacks the + // query string. TODO: RTFM on zig to figure out why + aws_request.query = try options.client.allocator.dupe(u8, aws_request.path[inx..]); + aws_request.path = try options.client.allocator.dupe(u8, aws_request.path[0..inx]); + log.debug("inx: {d}\n\tnew path: {s}\n\tnew query: {s}", .{ inx, aws_request.path, aws_request.query }); + options.client.allocator.free(orig_path); + options.client.allocator.free(orig_query); + } + } log.debug("Rest query: '{s}'", .{aws_request.query}); defer options.client.allocator.free(aws_request.query); // We don't know if we need a body...guessing here, this should cover most @@ -126,12 +143,20 @@ pub fn Request(comptime action: anytype) type { try json.stringify(request, .{ .whitespace = .{} }, buffer.writer()); } } + aws_request.body = buffer.items; if (Self.service_meta.aws_protocol == .rest_xml) { if (std.mem.eql(u8, "PUT", aws_request.method) or std.mem.eql(u8, "POST", aws_request.method)) { - return error.NotImplemented; + if (@hasDecl(ActionRequest, "http_payload")) { + // We will assign the body to the value of the field denoted by + // the http_payload declaration on the request type. + // Hopefully these will always be ?[]const u8, otherwise + // we should see a compile error on this line + aws_request.body = @field(request, ActionRequest.http_payload).?; + } else { + return error.NotImplemented; + } } } - aws_request.body = buffer.items; return try Self.callAws(aws_request, .{ .success_http_code = Action.http_config.success_code, @@ -238,8 +263,58 @@ pub fn Request(comptime action: anytype) type { return error.HttpFailure; } + var fullResponse = try getFullResponseFromBody(aws_request, response, options); + // Fill in any fields that require a header. Note doing it post-facto + // assumes all response header fields are optional, which may be incorrect + if (@hasDecl(action.Response, "http_header")) { + inline for (std.meta.fields(@TypeOf(action.Response.http_header))) |f| { + const header_name = @field(action.Response.http_header, f.name); + for (response.headers) |h| { + if (std.ascii.eqlIgnoreCase(h.name, header_name)) { + log.debug("Response header {s} configured for field. Setting {s} = {s}", .{ h.name, f.name, h.value }); + const field_type = @TypeOf(@field(fullResponse.response, f.name)); + // TODO: Fix this. We need to make this much more robust + // The deal is we have to do the dupe though + // Also, this is a memory leak atm + if (field_type == ?[]const u8) { + @field(fullResponse.response, f.name) = try options.client.allocator.dupe(u8, (try coerceFromString(field_type, h.value)).?); + } else { + @field(fullResponse.response, f.name) = try coerceFromString(field_type, h.value); + } + + break; + } + } + } + } + return fullResponse; + } + + fn getFullResponseFromBody(aws_request: awshttp.HttpRequest, response: awshttp.HttpResult, options: Options) !FullResponseType { + // First, we need to determine if we care about a response at all + // If the expected result has no fields, there's no sense in + // doing any more work. Let's bail early + var expected_body_field_len = std.meta.fields(action.Response).len; + if (@hasDecl(action.Response, "http_header")) + expected_body_field_len -= std.meta.fields(@TypeOf(action.Response.http_header)).len; + + // We don't care about the body if there are no fields we expect there... + if (std.meta.fields(action.Response).len == 0 or expected_body_field_len == 0) { + // ^^ This should be redundant, but is necessary. I suspect it's a compiler quirk + // + // Do we care if an unexpected body comes in? + return FullResponseType{ + .response = .{}, + .response_metadata = .{ + .request_id = try requestIdFromHeaders(aws_request, response, options), + }, + .parser_options = .{ .json = .{} }, + .raw_parsed = .{ .raw = .{} }, + .allocator = options.client.allocator, + }; + } const isJson = try isJsonResponse(response.headers); - if (!isJson) return try xmlReturn(options, response); + if (!isJson) return try xmlReturn(aws_request, options, response); return try jsonReturn(aws_request, options, response); } @@ -252,19 +327,6 @@ pub fn Request(comptime action: anytype) type { .allow_missing_fields = false, // new option. Cannot yet handle non-struct fields though }; - // If the expected result has no fields, there's no sense in - // doing any more work. Let's bail early - if (std.meta.fields(action.Response).len == 0) // We don't care about the body if there are no fields - // Do we care if an unexpected body comes in? - return FullResponseType{ - .response = .{}, - .response_metadata = .{ - .request_id = try requestIdFromHeaders(aws_request, response, options), - }, - .parser_options = .{ .json = parser_options }, - .raw_parsed = .{ .raw = .{} }, - }; - // Get our possible response types. There are 3: // // 1. A result wrapped with metadata like request ID. This is ServerResponse(action) @@ -301,6 +363,7 @@ pub fn Request(comptime action: anytype) type { }, .parser_options = .{ .json = parser_options }, .raw_parsed = .{ .server = parsed_response }, + .allocator = options.client.allocator, }; } else { // Conditions 2 or 3 (no wrapping) @@ -311,11 +374,12 @@ pub fn Request(comptime action: anytype) type { }, .parser_options = .{ .json = parser_options }, .raw_parsed = .{ .raw = parsed_response }, + .allocator = options.client.allocator, }; } } - fn xmlReturn(options: Options, result: awshttp.HttpResult) !FullResponseType { + fn xmlReturn(request: awshttp.HttpRequest, options: Options, result: awshttp.HttpResult) !FullResponseType { // Server shape be all like: // // @@ -353,42 +417,13 @@ pub fn Request(comptime action: anytype) type { defer if (free_body) options.client.allocator.free(body); const parsed = try xml_shaper.parse(action.Response, body, xml_options); errdefer parsed.deinit(); - var free_rid = false; // This needs to get into FullResponseType somehow: defer parsed.deinit(); const request_id = blk: { if (parsed.document.root.getCharData("requestId")) |elem| - break :blk elem; - var rid: ?[]const u8 = null; - // This "thing" is called: - // * Host ID - // * Extended Request ID - // * Request ID 2 - // - // I suspect it identifies the S3 frontend server and they are - // trying to obscure that fact. But several SDKs go with host id, - // so we'll use that - var host_id: ?[]const u8 = null; - for (result.headers) |header| { - if (std.ascii.eqlIgnoreCase(header.name, "x-amzn-requestid")) { // CloudFront - rid = header.value; - } - if (std.ascii.eqlIgnoreCase(header.name, "x-amz-request-id")) { // S3 - rid = header.value; - } - if (std.ascii.eqlIgnoreCase(header.name, "x-amz-id-2")) { // S3 - host_id = header.value; - } - } - if (rid) |r| { - if (host_id) |h| { - free_rid = true; - break :blk try std.fmt.allocPrint(options.client.allocator, "{s}, host_id: {s}", .{ r, h }); - } - break :blk r; - } - return error.RequestIdNotFound; + break :blk try options.client.allocator.dupe(u8, elem); + break :blk try requestIdFromHeaders(request, result, options); }; - defer if (free_rid) options.client.allocator.free(request_id); + defer options.client.allocator.free(request_id); return FullResponseType{ .response = parsed.parsed_value, @@ -397,6 +432,7 @@ pub fn Request(comptime action: anytype) type { }, .parser_options = .{ .xml = xml_options }, .raw_parsed = .{ .xml = parsed }, + .allocator = options.client.allocator, }; } const ServerResponseTypes = struct { @@ -532,6 +568,63 @@ pub fn Request(comptime action: anytype) type { }; } +fn coerceFromString(comptime T: type, val: []const u8) !T { + if (@typeInfo(T) == .Optional) return try coerceFromString(@typeInfo(T).Optional.child, val); + // TODO: This is terrible...fix it + switch (T) { + bool => return std.ascii.eqlIgnoreCase(val, "true"), + i64 => return try std.fmt.parseInt(T, val, 10), + else => return val, + } +} + +fn generalAllocPrint(allocator: std.mem.Allocator, val: anytype) !?[]const u8 { + switch (@typeInfo(@TypeOf(val))) { + .Optional => if (val) |v| return generalAllocPrint(allocator, v) else return null, + .Array, .Pointer => return try std.fmt.allocPrint(allocator, "{s}", .{val}), + else => return try std.fmt.allocPrint(allocator, "{any}", .{val}), + } +} +fn headersFor(allocator: std.mem.Allocator, request: anytype) ![]awshttp.Header { + log.debug("Checking for headers to include for type {s}", .{@TypeOf(request)}); + if (!@hasDecl(@TypeOf(request), "http_header")) return &[_]awshttp.Header{}; + const http_header = @TypeOf(request).http_header; + const fields = std.meta.fields(@TypeOf(http_header)); + log.debug("Found {d} possible custom headers", .{fields.len}); + // It would be awesome to have a fixed array, but we can't because + // it depends on a runtime value based on whether these variables are null + var headers = try std.ArrayList(awshttp.Header).initCapacity(allocator, fields.len); + inline for (fields) |f| { + // Header name = value of field + // Header value = value of the field of the request based on field name + const val = @field(request, f.name); + const final_val: ?[]const u8 = try generalAllocPrint(allocator, val); + if (final_val) |v| { + headers.appendAssumeCapacity(.{ + .name = @field(http_header, f.name), + .value = v, + }); + } + } + return headers.toOwnedSlice(); +} + +fn freeHeadersFor(allocator: std.mem.Allocator, request: anytype, headers: []awshttp.Header) void { + if (!@hasDecl(@TypeOf(request), "http_header")) return; + const http_header = @TypeOf(request).http_header; + const fields = std.meta.fields(@TypeOf(http_header)); + inline for (fields) |f| { + const header_name = @field(http_header, f.name); + for (headers) |h| { + if (std.mem.eql(u8, h.name, header_name)) { + allocator.free(h.value); + break; + } + } + } + allocator.free(headers); +} + fn firstJsonKey(data: []const u8) []const u8 { const start = std.mem.indexOf(u8, data, "\"") orelse 0; // Should never be 0 if (start == 0) log.warn("Response body missing json key?!", .{}); @@ -572,19 +665,31 @@ fn isJsonResponse(headers: []awshttp.Header) !bool { } /// Get request ID from headers. Caller responsible for freeing memory fn requestIdFromHeaders(request: awshttp.HttpRequest, response: awshttp.HttpResult, options: Options) ![]u8 { - var request_id: []u8 = undefined; - var found = false; - for (response.headers) |h| { - if (std.ascii.eqlIgnoreCase(h.name, "X-Amzn-RequestId")) { - found = true; - request_id = try std.fmt.allocPrint(options.client.allocator, "{s}", .{h.value}); // will be freed in FullR.deinit() - } + var rid: ?[]const u8 = null; + // This "thing" is called: + // * Host ID + // * Extended Request ID + // * Request ID 2 + // + // I suspect it identifies the S3 frontend server and they are + // trying to obscure that fact. But several SDKs go with host id, + // so we'll use that + var host_id: ?[]const u8 = null; + for (response.headers) |header| { + if (std.ascii.eqlIgnoreCase(header.name, "x-amzn-requestid")) // CloudFront + rid = header.value; + if (std.ascii.eqlIgnoreCase(header.name, "x-amz-request-id")) // S3 + rid = header.value; + if (std.ascii.eqlIgnoreCase(header.name, "x-amz-id-2")) // S3 + host_id = header.value; } - if (!found) { - try reportTraffic(options.client.allocator, "Request ID not found", request, response, log.err); - return error.RequestIdNotFound; + if (rid) |r| { + if (host_id) |h| + return try std.fmt.allocPrint(options.client.allocator, "{s}, host_id: {s}", .{ r, h }); + return try options.client.allocator.dupe(u8, r); } - return request_id; + try reportTraffic(options.client.allocator, "Request ID not found", request, response, log.err); + return error.RequestIdNotFound; } fn ServerResponse(comptime action: anytype) type { const T = action.Response; @@ -649,6 +754,7 @@ fn FullResponse(comptime action: anytype) type { raw: action.Response, xml: xml_shaper.Parsed(action.Response), }, + allocator: std.mem.Allocator, const Self = @This(); pub fn deinit(self: Self) void { @@ -660,12 +766,21 @@ fn FullResponse(comptime action: anytype) type { .xml => |xml| xml.deinit(), } - var allocator: std.mem.Allocator = undefined; - switch (self.parser_options) { - .json => |j| allocator = j.allocator.?, - .xml => |x| allocator = x.allocator.?, + self.allocator.free(self.response_metadata.request_id); + const Response = @TypeOf(self.response); + if (@hasDecl(Response, "http_header")) { + inline for (std.meta.fields(@TypeOf(Response.http_header))) |f| { + const field_type = @TypeOf(@field(self.response, f.name)); + // TODO: Fix this. We need to make this much more robust + // The deal is we have to do the dupe though + // Also, this is a memory leak atm + if (field_type == ?[]const u8) { + if (@field(self.response, f.name) != null) { + self.allocator.free(@field(self.response, f.name).?); + } + } + } } - allocator.free(self.response_metadata.request_id); } }; } diff --git a/src/aws_http.zig b/src/aws_http.zig index f8215dc..3cc2d07 100644 --- a/src/aws_http.zig +++ b/src/aws_http.zig @@ -150,6 +150,7 @@ pub const AwsHttp = struct { log.debug("Request Path: {s}", .{request_cp.path}); log.debug("Endpoint Path (actually used): {s}", .{endpoint.path}); log.debug("Query: {s}", .{request_cp.query}); + log.debug("Request additional header count: {d}", .{request_cp.headers.len}); log.debug("Method: {s}", .{request_cp.method}); log.debug("body length: {d}", .{request_cp.body.len}); log.debug("Body\n====\n{s}\n====", .{request_cp.body}); @@ -240,10 +241,18 @@ fn getRegion(service: []const u8, region: []const u8) []const u8 { } fn addHeaders(allocator: std.mem.Allocator, headers: *std.ArrayList(base.Header), host: []const u8, body: []const u8, content_type: []const u8, additional_headers: []Header) !?[]const u8 { + var has_content_type = false; + for (additional_headers) |h| { + if (std.ascii.eqlIgnoreCase(h.name, "Content-Type")) { + has_content_type = true; + break; + } + } try headers.append(.{ .name = "Accept", .value = "application/json" }); try headers.append(.{ .name = "Host", .value = host }); try headers.append(.{ .name = "User-Agent", .value = "zig-aws 1.0, Powered by the AWS Common Runtime." }); - try headers.append(.{ .name = "Content-Type", .value = content_type }); + if (!has_content_type) + try headers.append(.{ .name = "Content-Type", .value = content_type }); try headers.appendSlice(additional_headers); if (body.len > 0) { const len = try std.fmt.allocPrint(allocator, "{d}", .{body.len}); diff --git a/src/aws_signing.zig b/src/aws_signing.zig index 37361d8..6793e13 100644 --- a/src/aws_signing.zig +++ b/src/aws_signing.zig @@ -357,7 +357,9 @@ fn createCanonicalRequest(allocator: std.mem.Allocator, request: base.Request, p // TODO: This is all better as a writer - less allocations/copying const canonical_method = canonicalRequestMethod(request.method); - const canonical_url = try canonicalUri(allocator, request.path, true); // TODO: set false for s3 + // Let's not mess around here...s3 is the oddball + const double_encode = !std.mem.eql(u8, config.service, "s3"); + const canonical_url = try canonicalUri(allocator, request.path, double_encode); defer allocator.free(canonical_url); log.debug("final uri: {s}", .{canonical_url}); const canonical_query = try canonicalQueryString(allocator, request.query); @@ -408,8 +410,6 @@ fn canonicalUri(allocator: std.mem.Allocator, path: []const u8, double_encode: b // // For now, we will "Remove redundant and relative path components". This // doesn't apply to S3 anyway, and we'll make it the callers's problem - if (!double_encode) - return SigningError.S3NotImplemented; if (path.len == 0 or path[0] == '?' or path[0] == '#') return try allocator.dupe(u8, "/"); log.debug("encoding path: {s}", .{path}); diff --git a/src/main.zig b/src/main.zig index 621d0d8..5ab5a01 100644 --- a/src/main.zig +++ b/src/main.zig @@ -50,6 +50,7 @@ const Tests = enum { rest_json_1_work_with_lambda, rest_xml_no_input, rest_xml_anything_but_s3, + rest_xml_work_with_s3, }; pub fn main() anyerror!void {