diff --git a/build.zig b/build.zig index 2230d29..633e1c4 100644 --- a/build.zig +++ b/build.zig @@ -67,10 +67,11 @@ pub fn build(b: *Builder) !void { // Creates a step for unit testing. This only builds the test executable // but does not run it. const unit_tests = b.addTest(.{ - .root_source_file = .{ .path = "src/main.zig" }, + .root_source_file = .{ .path = "src/aws.zig" }, .target = target, .optimize = optimize, }); + unit_tests.addModule("smithy", smithy_dep.module("smithy")); const run_unit_tests = b.addRunArtifact(unit_tests); diff --git a/src/aws.zig b/src/aws.zig index d39ad25..2d4f01e 100644 --- a/src/aws.zig +++ b/src/aws.zig @@ -862,7 +862,7 @@ fn FullResponse(comptime action: anytype) type { // TODO: Fix this. We need to make this much more robust // The deal is we have to do the dupe though // Also, this is a memory leak atm - if (field_type == ?[]const u8) { + if (@typeInfo(field_type) == .Optional) { if (@field(self.response, f.name) != null) { self.allocator.free(@field(self.response, f.name).?); } @@ -1120,41 +1120,41 @@ fn reportTraffic(allocator: std.mem.Allocator, info: []const u8, request: awshtt reporter("{s}\n", .{msg.items}); } -// TODO: Where does this belong really? -fn typeForField(comptime T: type, field_name: []const u8) !type { - const ti = @typeInfo(T); - switch (ti) { - .Struct => { - inline for (ti.Struct.fields) |field| { - if (std.mem.eql(u8, field.name, field_name)) - return field.type; - } - }, - else => return error.TypeIsNotAStruct, // should not hit this - } - return error.FieldNotFound; -} - -test "custom serialization for map objects" { - const allocator = std.testing.allocator; - var buffer = std.ArrayList(u8).init(allocator); - defer buffer.deinit(); - var tags = try std.ArrayList(@typeInfo(try typeForField(services.lambda.tag_resource.Request, "tags")).Pointer.child).initCapacity(allocator, 2); - defer tags.deinit(); - tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" }); - tags.appendAssumeCapacity(.{ .key = "Baz", .value = "Qux" }); - const req = services.lambda.tag_resource.Request{ .resource = "hello", .tags = tags.items }; - try json.stringify(req, .{ .whitespace = .{} }, buffer.writer()); - try std.testing.expectEqualStrings( - \\{ - \\ "Resource": "hello", - \\ "Tags": { - \\ "Foo": "Bar", - \\ "Baz": "Qux" - \\ } - \\} - , buffer.items); -} +// // TODO: Where does this belong really? +// fn typeForField(comptime T: type, field_name: []const u8) !type { +// const ti = @typeInfo(T); +// switch (ti) { +// .Struct => { +// inline for (ti.Struct.fields) |field| { +// if (std.mem.eql(u8, field.name, field_name)) +// return field.type; +// } +// }, +// else => return error.TypeIsNotAStruct, // should not hit this +// } +// return error.FieldNotFound; +// } +// +// test "custom serialization for map objects" { +// const allocator = std.testing.allocator; +// var buffer = std.ArrayList(u8).init(allocator); +// defer buffer.deinit(); +// var tags = try std.ArrayList(@typeInfo(try typeForField(services.lambda.tag_resource.Request, "tags")).Pointer.child).initCapacity(allocator, 2); +// defer tags.deinit(); +// tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" }); +// tags.appendAssumeCapacity(.{ .key = "Baz", .value = "Qux" }); +// const req = services.lambda.tag_resource.Request{ .resource = "hello", .tags = tags.items }; +// try json.stringify(req, .{ .whitespace = .{} }, buffer.writer()); +// try std.testing.expectEqualStrings( +// \\{ +// \\ "Resource": "hello", +// \\ "Tags": { +// \\ "Foo": "Bar", +// \\ "Baz": "Qux" +// \\ } +// \\} +// , buffer.items); +// } test "REST Json v1 builds proper queries" { const allocator = std.testing.allocator; @@ -1261,7 +1261,7 @@ test "layer object only" { // const response = // \\ { // \\ "UncompressedCodeSize": 22599541, - // \\ "Arn": "arn:aws:lambda:us-west-2:550620852718:layer:PollyNotes-lib:4" + // \\ "Arn": "arn:aws:lambda:us-west-2:123456789012:layer:PollyNotes-lib:4" // \\ } // ; const allocator = std.testing.allocator; @@ -1298,4 +1298,182 @@ test "layer object only" { // const SResponse = ServerResponse(request); // const r = try json.parse(SResponse, &stream, parser_options); // json.parseFree(SResponse, r, parser_options); -// } + +//////////////////////////////////////////////////////////////////////// +// All code below this line is for testing +//////////////////////////////////////////////////////////////////////// + +test { + // To run nested container tests, either, call `refAllDecls` which will + // reference all declarations located in the given argument. + // `@This()` is a builtin function that returns the innermost container it is called from. + // In this example, the innermost container is this file (implicitly a struct). + // TODO: re-enable this + // std.testing.refAllDecls(@This()); + // std.testing.refAllDecls(config); + // std.testing.refAllDecls(interface); +} +const TestOptions = struct { + allocator: std.mem.Allocator, + server_port: ?u16 = null, + server_remaining_requests: usize = 1, + server_response: []const u8 = "unset", + server_response_headers: [][2][]const u8 = &[_][2][]const u8{}, + request_body: []u8 = "", + test_server_runtime_uri: ?[]u8 = null, + server_ready: bool = false, + + const Self = @This(); + + fn waitForReady(self: *Self) !void { + // While this doesn't return an error, we can use !void + // to prepare for addition of timeout + while (!self.server_ready) + std.time.sleep(100); + } + + fn deinit(self: Self) void { + if (self.request_body.len > 0) + self.allocator.free(self.request_body); + // if (self.test_server_runtime_uri) |_| + // self.allocator.free(self.test_server_runtime_uri.?); + } +}; + +/// This starts a test server. We're not testing the server itself, +/// so the main tests will start this thing up and create an arena around the +/// whole thing so we can just deallocate everything at once at the end, +/// leaks be damned +fn threadMain(options: *TestOptions) !void { + var server = std.http.Server.init(options.allocator, .{ .reuse_address = true }); + // defer server.deinit(); + + const address = try std.net.Address.parseIp("127.0.0.1", 0); + try server.listen(address); + options.server_port = server.socket.listen_address.in.getPort(); + + options.test_server_runtime_uri = try std.fmt.allocPrint(options.allocator, "http://127.0.0.1:{d}", .{options.server_port.?}); + defer options.allocator.free(options.test_server_runtime_uri.?); + log.debug("server listening at {s}", .{options.test_server_runtime_uri.?}); + defer server.deinit(); + log.info("starting server thread, tid {d}", .{std.Thread.getCurrentId()}); + // var arena = std.heap.ArenaAllocator.init(options.allocator); + // defer arena.deinit(); + // var aa = arena.allocator(); + // We're in control of all requests/responses, so this flag will tell us + // when it's time to shut down + while (options.server_remaining_requests > 0) { + options.server_remaining_requests -= 1; + processRequest(options, &server) catch |e| { + log.err("Unexpected error processing request: {any}", .{e}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + }; + } +} + +fn processRequest(options: *TestOptions, server: *std.http.Server) !void { + options.server_ready = true; + errdefer options.server_ready = false; + log.debug( + "tid {d} (server): server waiting to accept. requests remaining: {d}", + .{ std.Thread.getCurrentId(), options.server_remaining_requests + 1 }, + ); + var res = try server.accept(.{ .allocator = options.allocator }); + options.server_ready = false; + defer res.deinit(); + defer _ = res.reset(); + try res.wait(); // wait for client to send a complete request head + + const errstr = "Internal Server Error\n"; + var errbuf: [errstr.len]u8 = undefined; + @memcpy(&errbuf, errstr); + var response_bytes: []const u8 = errbuf[0..]; + + if (res.request.content_length) |l| + options.request_body = try res.reader().readAllAlloc(options.allocator, @as(usize, l)); + + log.debug( + "tid {d} (server): {d} bytes read from request", + .{ std.Thread.getCurrentId(), options.request_body.len }, + ); + + // try response.headers.append("content-type", "text/plain"); + response_bytes = serve(options, &res) catch |e| brk: { + res.status = .internal_server_error; + // TODO: more about this particular request + log.err("Unexpected error from executor processing request: {any}", .{e}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + break :brk "Unexpected error generating request to lambda"; + }; + res.transfer_encoding = .{ .content_length = response_bytes.len }; + try res.do(); + _ = try res.writer().writeAll(response_bytes); + try res.finish(); + log.debug( + "tid {d} (server): sent response", + .{std.Thread.getCurrentId()}, + ); +} + +fn serve(options: *TestOptions, res: *std.http.Server.Response) ![]const u8 { + for (options.server_response_headers) |h| + try res.headers.append(h[0], h[1]); + // try res.headers.append("content-length", try std.fmt.allocPrint(allocator, "{d}", .{server_response.len})); + return options.server_response; +} +const TestHeader = struct { + name: []const u8, + value: []const u8, +}; +test "sts get_caller_identity comptime" { + // std.testing.log_level = .debug; + const allocator = std.testing.allocator; + // [debug] (awshttp): x-amzn-RequestId: 8f0d54da-1230-40f7-b4ac-95015c4b84cd + // [debug] (awshttp): Content-Type: application/json + var requestOptions: TestOptions = .{ + .allocator = allocator, + .server_response = + \\{"GetCallerIdentityResponse":{"GetCallerIdentityResult":{"Account":"123456789012","Arn":"arn:aws:iam::123456789012:user/admin","UserId":"AIDAYAM4POHXHRVANDQBQ"},"ResponseMetadata":{"RequestId":"8f0d54da-1230-40f7-b4ac-95015c4b84cd"}}} + , + .server_response_headers = @constCast(&[_][2][]const u8{ + .{ "Content-Type", "application/json" }, + .{ "x-amzn-RequestId", "8f0d54da-1230-40f7-b4ac-95015c4b84cd" }, + }), + }; + defer requestOptions.deinit(); + // Needs to go away: .request_body: []u8 = "", + const server_thread = try std.Thread.spawn( + .{}, + threadMain, + .{&requestOptions}, + ); + try requestOptions.waitForReady(); + + awshttp.endpoint_override = requestOptions.test_server_runtime_uri; + var client = try Client.init(allocator, .{}); + const options = Options{ + .region = "us-west-2", + .client = client, + }; + defer client.deinit(); + const sts = (Services(.{.sts}){}).sts; + const call = try Request(sts.get_caller_identity).call(.{}, options); + // const call = try client.call(services.sts.get_caller_identity.Request{}, options); + defer call.deinit(); + + server_thread.join(); + try std.testing.expectEqualStrings( + \\Action=GetCallerIdentity&Version=2011-06-15 + , requestOptions.request_body); + try std.testing.expectEqualStrings( + "arn:aws:iam::123456789012:user/admin", + call.response.arn.?, + ); + try std.testing.expectEqualStrings("AIDAYAM4POHXHRVANDQBQ", call.response.user_id.?); + try std.testing.expectEqualStrings("123456789012", call.response.account.?); + try std.testing.expectEqualStrings("8f0d54da-1230-40f7-b4ac-95015c4b84cd", call.response_metadata.request_id); +} diff --git a/src/aws_http.zig b/src/aws_http.zig index 1cfa7f2..37bedb0 100644 --- a/src/aws_http.zig +++ b/src/aws_http.zig @@ -263,10 +263,14 @@ fn getEnvironmentVariable(allocator: std.mem.Allocator, key: []const u8) !?[]con }; } +/// override endpoint url. Intended for use in testing. Normally, you should +/// rely on AWS_ENDPOINT_URL environment variable for this +pub var endpoint_override: ?[]const u8 = null; + fn endpointForRequest(allocator: std.mem.Allocator, service: []const u8, request: HttpRequest, options: Options) !EndPoint { - const environment_override = try getEnvironmentVariable(allocator, "AWS_ENDPOINT_URL"); + const environment_override = endpoint_override orelse try getEnvironmentVariable(allocator, "AWS_ENDPOINT_URL"); if (environment_override) |override| { - const uri = try allocator.dupeZ(u8, override); + const uri = try allocator.dupe(u8, override); return endPointFromUri(allocator, uri); } // Fallback to us-east-1 if global endpoint does not exist.