diff --git a/codegen/src/main.zig b/codegen/src/main.zig index 0108a2d..9abca93 100644 --- a/codegen/src/main.zig +++ b/codegen/src/main.zig @@ -340,12 +340,32 @@ fn generateSimpleTypeFor(_: anytype, type_name: []const u8, writer: anytype, all fn generateComplexTypeFor(allocator: *std.mem.Allocator, members: []smithy.TypeMember, type_type_name: []const u8, shapes: anytype, writer: anytype, prefix: []const u8, all_required: bool, type_stack: anytype) anyerror!void { const Mapping = struct { snake: []const u8, json: []const u8 }; - var mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len); + var json_field_name_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len); defer { - for (mappings.items) |mapping| { + for (json_field_name_mappings.items) |mapping| { allocator.free(mapping.snake); } - mappings.deinit(); + json_field_name_mappings.deinit(); + } + // There is an httpQueryParams trait as well, but nobody is using it. API GW + // pretends to, but it's an empty map + // + // Same with httpPayload + // + // httpLabel is interesting - right now we just assume anything can be used - do we need to track this? + var http_query_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len); + defer { + for (http_query_mappings.items) |mapping| { + allocator.free(mapping.snake); + } + http_query_mappings.deinit(); + } + var http_header_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len); + defer { + for (http_header_mappings.items) |mapping| { + allocator.free(mapping.snake); + } + http_header_mappings.deinit(); } // prolog. We'll rely on caller to get the spacing correct here _ = try writer.write(type_type_name); @@ -359,15 +379,20 @@ fn generateComplexTypeFor(allocator: *std.mem.Allocator, members: []smithy.TypeM // in API Gateway. Not sure what we're supposed to do there. Checking the go // sdk, they move this particular duplicate to 'http_method' - not sure yet // if this is a hard-coded exception` - var found_trait = false; + var found_name_trait = false; for (member.traits) |trait| { - if (trait == .json_name) { - found_trait = true; - mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .json = trait.json_name }); + switch (trait) { + .json_name => { + found_name_trait = true; + json_field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .json = trait.json_name }); + }, + .http_query => http_query_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .json = trait.http_query }), + .http_header => http_header_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .json = trait.http_header }), + else => {}, } } - if (!found_trait) - mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .json = member.name }); + if (!found_name_trait) + json_field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .json = member.name }); defer allocator.free(snake_case_member); try writer.print("{s} {s}: ", .{ prefix, avoidReserved(snake_case_member) }); if (!all_required) try writeOptional(member.traits, writer, null); @@ -377,6 +402,20 @@ fn generateComplexTypeFor(allocator: *std.mem.Allocator, members: []smithy.TypeM _ = try writer.write(",\n"); } + // Add in http query metadata (only relevant to REST JSON APIs - do we care? + // pub const http_query = .{ + // .master_region = "MasterRegion", + // .function_version = "FunctionVersion", + // .marker = "Marker", + // .max_items = "MaxItems", + // }; + var constprefix = try std.fmt.allocPrint(allocator, "{s} ", .{prefix}); + defer allocator.free(constprefix); + if (http_query_mappings.items.len > 0) _ = try writer.write("\n"); + try writeMappings(constprefix, "pub ", "http_query", http_query_mappings, writer); + if (http_query_mappings.items.len > 0 and http_header_mappings.items.len > 0) _ = try writer.write("\n"); + try writeMappings(constprefix, "pub ", "http_header", http_header_mappings, writer); + // Add in json mappings. The function looks like this: // // pub fn jsonFieldNameFor(_: @This(), comptime field_name: []const u8) []const u8 { @@ -387,21 +426,26 @@ fn generateComplexTypeFor(allocator: *std.mem.Allocator, members: []smithy.TypeM // return @field(mappings, field_name); // } // - // TODO: There is a smithy trait that will specify the json name. We should be using - // this instead if applicable. + var fieldnameprefix = try std.fmt.allocPrint(allocator, "{s} ", .{prefix}); + defer allocator.free(fieldnameprefix); try writer.print("\n{s} pub fn jsonFieldNameFor(_: @This(), comptime field_name: []const u8) []const u8 ", .{prefix}); _ = try writer.write("{\n"); - try writer.print("{s} const mappings = .", .{prefix}); - _ = try writer.write("{\n"); - for (mappings.items) |mapping| { - try writer.print("{s} .{s} = \"{s}\",\n", .{ prefix, avoidReserved(mapping.snake), mapping.json }); - } - _ = try writer.write(prefix); - _ = try writer.write(" };\n"); + try writeMappings(fieldnameprefix, "", "mappings", json_field_name_mappings, writer); try writer.print("{s} return @field(mappings, field_name);\n{s}", .{ prefix, prefix }); _ = try writer.write(" }\n"); } +fn writeMappings(prefix: []const u8, @"pub": []const u8, mapping_name: []const u8, mappings: anytype, writer: anytype) !void { + if (mappings.items.len == 0) return; + try writer.print("{s}{s}const {s} = .", .{ prefix, @"pub", mapping_name }); + _ = try writer.write("{\n"); + for (mappings.items) |mapping| { + try writer.print("{s} .{s} = \"{s}\",\n", .{ prefix, avoidReserved(mapping.snake), mapping.json }); + } + _ = try writer.write(prefix); + _ = try writer.write("};\n"); +} + fn writeOptional(traits: ?[]smithy.Trait, writer: anytype, value: ?[]const u8) !void { if (traits) |ts| { for (ts) |t| diff --git a/smithy/src/smithy.zig b/smithy/src/smithy.zig index 44411e3..636124b 100644 --- a/smithy/src/smithy.zig +++ b/smithy/src/smithy.zig @@ -92,6 +92,9 @@ pub const TraitType = enum { aws_protocol, ec2_query_name, http, + http_header, + http_label, + http_query, json_name, required, documentation, @@ -120,6 +123,9 @@ pub const Trait = union(TraitType) { uri: []const u8, code: i64 = 200, }, + http_header: []const u8, + http_label: []const u8, + http_query: []const u8, required: struct {}, documentation: []const u8, pattern: []const u8, @@ -559,6 +565,10 @@ fn getTrait(trait_type: []const u8, value: std.json.Value) SmithyParseError!?Tra } if (std.mem.eql(u8, trait_type, "smithy.api#jsonName")) return Trait{ .json_name = value.String }; + if (std.mem.eql(u8, trait_type, "smithy.api#httpQuery")) + return Trait{ .http_query = value.String }; + if (std.mem.eql(u8, trait_type, "smithy.api#httpHeader")) + return Trait{ .http_header = value.String }; // TODO: Maybe care about these traits? if (std.mem.eql(u8, trait_type, "smithy.api#title")) @@ -583,14 +593,11 @@ fn getTrait(trait_type: []const u8, value: std.json.Value) SmithyParseError!?Tra \\smithy.api#eventPayload \\smithy.api#externalDocumentation \\smithy.api#hostLabel - \\smithy.api#http \\smithy.api#httpError \\smithy.api#httpChecksumRequired - \\smithy.api#httpHeader \\smithy.api#httpLabel \\smithy.api#httpPayload \\smithy.api#httpPrefixHeaders - \\smithy.api#httpQuery \\smithy.api#httpQueryParams \\smithy.api#httpResponseCode \\smithy.api#idempotencyToken diff --git a/src/aws.zig b/src/aws.zig index fbb9810..ce6da9d 100644 --- a/src/aws.zig +++ b/src/aws.zig @@ -11,6 +11,7 @@ const log = std.log.scoped(.aws); pub const Options = struct { region: []const u8 = "aws-global", dualstack: bool = false, + success_http_code: i64 = 200, }; /// Using this constant may blow up build times. Recommed using Services() @@ -65,12 +66,45 @@ pub const Aws = struct { switch (service_meta.aws_protocol) { .query => return self.callQuery(request, service_meta, action, options), // .query, .ec2_query => return self.callQuery(request, service_meta, action, options), - .rest_json_1, .json_1_0, .json_1_1 => return self.callJson(request, service_meta, action, options), + .json_1_0, .json_1_1 => return self.callJson(request, service_meta, action, options), + .rest_json_1 => return self.callRestJson(request, service_meta, action, options), .ec2_query, .rest_xml => @compileError("XML responses may be blocked on a zig compiler bug scheduled to be fixed in 0.9.0"), } } - /// Calls using one of the json protocols (rest_json_1, json_1_0, json_1_1 + /// Rest Json is the most complex and so we handle this seperately + fn callRestJson(self: Self, comptime request: anytype, comptime service_meta: anytype, action: anytype, options: Options) !FullResponse(request) { + const Action = @TypeOf(action); + var aws_request: awshttp.HttpRequest = .{ + .method = Action.http_config.method, + .content_type = "application/json", + .path = Action.http_config.uri, + }; + + log.debug("Rest JSON v1 method: {s}", .{aws_request.method}); + log.debug("Rest JSON v1 success code: {d}", .{Action.http_config.success_code}); + log.debug("Rest JSON v1 raw uri: {s}", .{Action.http_config.uri}); + + aws_request.query = try buildQuery(self.allocator, request); + log.debug("Rest JSON v1 query: {s}", .{aws_request.query}); + defer self.allocator.free(aws_request.query); + // We don't know if we need a body...guessing here, this should cover most + var buffer = std.ArrayList(u8).init(self.allocator); + defer buffer.deinit(); + var nameAllocator = std.heap.ArenaAllocator.init(self.allocator); + defer nameAllocator.deinit(); + if (std.mem.eql(u8, "PUT", aws_request.method) or std.mem.eql(u8, "POST", aws_request.method)) { + try json.stringify(request, .{ .whitespace = .{} }, buffer.writer()); + } + + return try self.callAws(request, service_meta, aws_request, .{ + .success_http_code = Action.http_config.success_code, + .region = options.region, + .dualstack = options.dualstack, + }); + } + + /// Calls using one of the json protocols (json_1_0, json_1_1) fn callJson(self: Self, comptime request: anytype, comptime service_meta: anytype, action: anytype, options: Options) !FullResponse(request) { const target = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ @@ -97,7 +131,6 @@ pub const Aws = struct { var content_type: []const u8 = undefined; switch (service_meta.aws_protocol) { - .rest_json_1 => content_type = "application/json", .json_1_0 => content_type = "application/x-amz-json-1.0", .json_1_1 => content_type = "application/x-amz-json-1.1", else => unreachable, @@ -377,6 +410,53 @@ fn queryFieldTransformer(field_name: []const u8, encoding_options: url.EncodingO return try case.snakeToPascal(encoding_options.allocator.?, field_name); } +fn buildQuery(allocator: *std.mem.Allocator, comptime request: anytype) ![]const u8 { + // query should look something like this: + // pub const http_query = .{ + // .master_region = "MasterRegion", + // .function_version = "FunctionVersion", + // .marker = "Marker", + // }; + const query_arguments = @TypeOf(request).http_query; + var buffer = std.ArrayList(u8).init(allocator); + const writer = buffer.writer(); + defer buffer.deinit(); + var has_begun = false; + inline for (@typeInfo(@TypeOf(query_arguments)).Struct.fields) |arg| { + const val = @field(request, arg.name); + if (@typeInfo(@TypeOf(val)) == .Optional) { + if (val) |v| { + try addQueryArg(@field(query_arguments, arg.name), v, writer, !has_begun); + has_begun = true; + } + } else { + try addQueryArg(@field(query_arguments, arg.name), val, writer, !has_begun); + has_begun = true; + } + } + return buffer.toOwnedSlice(); +} + +fn addQueryArg(key: []const u8, value: anytype, writer: anytype, start: bool) !void { + if (start) + _ = try writer.write("?") + else + _ = try writer.write("&"); + // TODO: url escaping + try writer.print("{s}=", .{key}); + try json.stringify(value, .{}, writer); +} + +test "REST Json v1 builds proper queries" { + const allocator = std.testing.allocator; + const svs = Services(.{.lambda}){}; + const request = svs.lambda.list_functions.Request{ + .max_items = 1, + }; + const query = try buildQuery(allocator, request); + defer allocator.free(query); + try std.testing.expectEqualStrings("?MaxItems=1", query); +} test "basic json request serialization" { const allocator = std.testing.allocator; const svs = Services(.{.dynamo_db}){}; diff --git a/src/main.zig b/src/main.zig index 9fa87ec..b07279d 100644 --- a/src/main.zig +++ b/src/main.zig @@ -31,6 +31,8 @@ const Tests = enum { json_1_0_query_no_input, json_1_1_query_with_input, json_1_1_query_no_input, + rest_json_1_query_no_input, + rest_json_1_query_with_input, }; pub fn main() anyerror!void { @@ -69,53 +71,67 @@ pub fn main() anyerror!void { var client = aws.Aws.init(allocator); defer client.deinit(); - const services = aws.Services(.{ .sts, .ec2, .dynamo_db, .ecs }){}; + const services = aws.Services(.{ .sts, .ec2, .dynamo_db, .ecs, .lambda }){}; for (tests.items) |t| { std.log.info("===== Start Test: {s} =====", .{@tagName(t)}); switch (t) { .query_no_input => { - const resp = try client.call(services.sts.get_caller_identity.Request{}, options); - defer resp.deinit(); - std.log.info("arn: {s}", .{resp.response.arn}); - std.log.info("id: {s}", .{resp.response.user_id}); - std.log.info("account: {s}", .{resp.response.account}); - std.log.info("requestId: {s}", .{resp.response_metadata.request_id}); + const call = try client.call(services.sts.get_caller_identity.Request{}, options); + defer call.deinit(); + std.log.info("arn: {s}", .{call.response.arn}); + std.log.info("id: {s}", .{call.response.user_id}); + std.log.info("account: {s}", .{call.response.account}); + std.log.info("requestId: {s}", .{call.response_metadata.request_id}); }, .query_with_input => { // TODO: Find test without sensitive info - const access = try client.call(services.sts.get_session_token.Request{ + const call = try client.call(services.sts.get_session_token.Request{ .duration_seconds = 900, }, options); - defer access.deinit(); - std.log.info("access key: {s}", .{access.response.credentials.?.access_key_id}); + defer call.deinit(); + std.log.info("call key: {s}", .{call.response.credentials.?.access_key_id}); }, .json_1_0_query_with_input => { - const tables = try client.call(services.dynamo_db.list_tables.Request{ + const call = try client.call(services.dynamo_db.list_tables.Request{ .limit = 1, }, options); - defer tables.deinit(); - std.log.info("request id: {s}", .{tables.response_metadata.request_id}); - std.log.info("account has tables: {b}", .{tables.response.table_names.?.len > 0}); + defer call.deinit(); + std.log.info("request id: {s}", .{call.response_metadata.request_id}); + std.log.info("account has call: {b}", .{call.response.table_names.?.len > 0}); }, .json_1_0_query_no_input => { - const limits = try client.call(services.dynamo_db.describe_limits.Request{}, options); - defer limits.deinit(); - std.log.info("account read capacity limit: {d}", .{limits.response.account_max_read_capacity_units}); + const call = try client.call(services.dynamo_db.describe_limits.Request{}, options); + defer call.deinit(); + std.log.info("account read capacity limit: {d}", .{call.response.account_max_read_capacity_units}); }, .json_1_1_query_with_input => { - const clusters = try client.call(services.ecs.list_clusters.Request{ + const call = try client.call(services.ecs.list_clusters.Request{ .max_results = 1, }, options); - defer clusters.deinit(); - std.log.info("request id: {s}", .{clusters.response_metadata.request_id}); - std.log.info("account has clusters: {b}", .{clusters.response.cluster_arns.?.len > 0}); + defer call.deinit(); + std.log.info("request id: {s}", .{call.response_metadata.request_id}); + std.log.info("account has call: {b}", .{call.response.cluster_arns.?.len > 0}); }, .json_1_1_query_no_input => { - const clusters = try client.call(services.ecs.list_clusters.Request{}, options); - defer clusters.deinit(); - std.log.info("request id: {s}", .{clusters.response_metadata.request_id}); - std.log.info("account has clusters: {b}", .{clusters.response.cluster_arns.?.len > 0}); + const call = try client.call(services.ecs.list_clusters.Request{}, options); + defer call.deinit(); + std.log.info("request id: {s}", .{call.response_metadata.request_id}); + std.log.info("account has call: {b}", .{call.response.cluster_arns.?.len > 0}); + }, + .rest_json_1_query_with_input => { + const call = try client.call(services.lambda.list_functions.Request{ + .max_items = 1, + }, options); + defer call.deinit(); + std.log.info("request id: {s}", .{call.response_metadata.request_id}); + std.log.info("account has call: {b}", .{call.response.functions.?.len > 0}); + }, + .rest_json_1_query_no_input => { + const call = try client.call(services.lambda.list_functions.Request{}, options); + defer call.deinit(); + std.log.info("request id: {s}", .{call.response_metadata.request_id}); + std.log.info("account has call: {b}", .{call.response.functions.?.len > 0}); }, .ec2_query_no_input => { std.log.err("EC2 Test disabled due to compiler bug", .{});