From 934323acf1719d401b7c848ea35c7c6c74b83835 Mon Sep 17 00:00:00 2001 From: Simon Hartcher Date: Wed, 14 May 2025 15:29:21 +1000 Subject: [PATCH] refactor: generate types for maps TODO: handle json stringify for generated map types --- codegen/src/main.zig | 147 ++++++++++++++++++++++++------------------ lib/json/src/json.zig | 58 ++++++++--------- src/aws.zig | 2 +- src/main.zig | 2 +- 4 files changed, 113 insertions(+), 96 deletions(-) diff --git a/codegen/src/main.zig b/codegen/src/main.zig index 984e3e7..7caac18 100644 --- a/codegen/src/main.zig +++ b/codegen/src/main.zig @@ -467,7 +467,9 @@ fn generateAdditionalTypes(allocator: std.mem.Allocator, file_state: FileGenerat .allocator = allocator, .indent_level = 0, }; - const type_name = avoidReserved(t.name); + const type_name = try getTypeName(allocator, t); + defer allocator.free(type_name); + try writer.print("\npub const {s} = ", .{type_name}); try file_state.additional_types_generated.putNoClobber(t.name, {}); _ = try generateTypeFor(t.id, writer, state, true); @@ -637,6 +639,18 @@ fn endsWith(item: []const u8, str: []const u8) bool { return std.mem.eql(u8, item, str[str.len - item.len ..]); } +fn getTypeName(allocator: std.mem.Allocator, shape: smithy.ShapeInfo) ![]const u8 { + const type_name = avoidReserved(shape.name); + + switch (shape.shape) { + .map => { + const map_type_name = avoidReserved(shape.name); + return try std.fmt.allocPrint(allocator, "{s}KeyValue", .{map_type_name[0 .. map_type_name.len - 1]}); + }, + else => return allocator.dupe(u8, type_name), + } +} + fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationState) !bool { // We want to return if we're at the top level of the stack. There are three // reasons for this: @@ -651,12 +665,21 @@ fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationSt // can at least see the top level. // 3. When we come through at the end, we want to make sure we're writing // something or we'll have an infinite loop! + + switch (shape.shape) { + .structure, .uniontype, .map => {}, + else => return false, + } + + const type_name = try getTypeName(state.allocator, shape); + defer state.allocator.free(type_name); + if (state.type_stack.items.len == 1) return false; var rc = false; if (state.file_state.shape_references.get(shape.id)) |r| { - if (r > 1 and (shape.shape == .structure or shape.shape == .uniontype)) { + if (r > 1) { rc = true; - _ = try writer.write(avoidReserved(shape.name)); // This can't possibly be this easy... + _ = try writer.write(type_name); // This can't possibly be this easy... if (state.file_state.additional_types_generated.getEntry(shape.name) == null) try state.file_state.additional_types_to_generate.append(shape); } @@ -755,34 +778,14 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState .double => |s| try generateSimpleTypeFor(s, "f64", writer), .float => |s| try generateSimpleTypeFor(s, "f32", writer), .long => |s| try generateSimpleTypeFor(s, "i64", writer), - .map => { - _ = try writer.write("[]struct {\n"); - var child_state = state; - child_state.indent_level += 1; - try outputIndent(child_state, writer); - _ = try writer.write("key: "); - try writeOptional(shape.map.traits, writer, null); - var sub_maps = std.ArrayList([]const u8).init(state.allocator); - defer sub_maps.deinit(); - if (try generateTypeFor(shape.map.key, writer, child_state, true)) - try sub_maps.append("key"); - try writeOptional(shape.map.traits, writer, " = null"); - _ = try writer.write(",\n"); - try outputIndent(child_state, writer); - _ = try writer.write("value: "); - try writeOptional(shape.map.traits, writer, null); - if (try generateTypeFor(shape.map.value, writer, child_state, true)) - try sub_maps.append("value"); - try writeOptional(shape.map.traits, writer, " = null"); - _ = try writer.write(",\n"); - if (sub_maps.items.len > 0) { - _ = try writer.write("\n"); - try writeStringify(state, sub_maps.items, writer); + .map => |m| { + if (!try reuseCommonType(shape_info, std.io.null_writer, state)) { + try generateMapTypeFor(m, writer, state); + rc = true; + } else { + try writer.writeAll("[]"); + _ = try reuseCommonType(shape_info, writer, state); } - try outputIndent(state, writer); - _ = try writer.write("}"); - - rc = true; }, else => { std.log.err("encountered unimplemented shape type {s} for shape_id {s}. Generated code will not compile", .{ @tagName(shape), shape_id }); @@ -793,41 +796,59 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState return rc; } +fn generateMapTypeFor(map: anytype, writer: anytype, state: GenerationState) anyerror!void { + _ = try writer.write("struct {\n"); + + var child_state = state; + child_state.indent_level += 1; + + _ = try writer.write("key: "); + try writeOptional(map.traits, writer, null); + + _ = try generateTypeFor(map.key, writer, child_state, true); + + try writeOptional(map.traits, writer, " = null"); + _ = try writer.write(",\n"); + + _ = try writer.write("value: "); + try writeOptional(map.traits, writer, null); + + _ = try generateTypeFor(map.value, writer, child_state, true); + + try writeOptional(map.traits, writer, " = null"); + _ = try writer.write(",\n"); + _ = try writer.write("}"); +} + fn generateSimpleTypeFor(_: anytype, type_name: []const u8, writer: anytype) !void { _ = try writer.write(type_name); // This had required stuff but the problem was elsewhere. Better to leave as function just in case } + +const Mapping = struct { snake: []const u8, original: []const u8 }; fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, writer: anytype, state: GenerationState) anyerror!void { _ = shape_id; - const Mapping = struct { snake: []const u8, original: []const u8 }; - var field_name_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len); - defer { - for (field_name_mappings.items) |mapping| - state.allocator.free(mapping.snake); - field_name_mappings.deinit(); - } + + var arena = std.heap.ArenaAllocator.init(state.allocator); + defer arena.deinit(); + const allocator = arena.allocator(); + + var field_name_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len); + defer field_name_mappings.deinit(); // There is an httpQueryParams trait as well, but nobody is using it. API GW // pretends to, but it's an empty map // // Same with httpPayload // // httpLabel is interesting - right now we just assume anything can be used - do we need to track this? - var http_query_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len); - defer { - for (http_query_mappings.items) |mapping| - state.allocator.free(mapping.snake); - http_query_mappings.deinit(); - } - var http_header_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len); - defer { - for (http_header_mappings.items) |mapping| - state.allocator.free(mapping.snake); - http_header_mappings.deinit(); - } - var map_fields = std.ArrayList([]const u8).init(state.allocator); - defer { - for (map_fields.items) |f| state.allocator.free(f); - map_fields.deinit(); - } + var http_query_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len); + defer http_query_mappings.deinit(); + + var http_header_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len); + defer http_header_mappings.deinit(); + + var map_fields = std.ArrayList([]const u8).init(allocator); + defer map_fields.deinit(); + // prolog. We'll rely on caller to get the spacing correct here _ = try writer.write(type_type_name); _ = try writer.write(" {\n"); @@ -836,7 +857,7 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty var payload: ?[]const u8 = null; for (members) |member| { // This is our mapping - const snake_case_member = try constantName(state.allocator, member.name); + const snake_case_member = try constantName(allocator, member.name); // So it looks like some services have duplicate names?! Check out "httpMethod" // in API Gateway. Not sure what we're supposed to do there. Checking the go // sdk, they move this particular duplicate to 'http_method' - not sure yet @@ -846,34 +867,34 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty switch (trait) { .json_name => |n| { found_name_trait = true; - field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n }); + field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n }); }, .xml_name => |n| { found_name_trait = true; - field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n }); + field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n }); }, - .http_query => |n| http_query_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n }), - .http_header => http_header_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = trait.http_header }), + .http_query => |n| http_query_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n }), + .http_header => http_header_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = trait.http_header }), .http_payload => { // Don't assert as that will be optimized for Release* builds // We'll continue here and treat the above as a warning if (payload) |first| { std.log.err("Found multiple httpPayloads in violation of smithy spec! Ignoring '{s}' and using '{s}'", .{ first, snake_case_member }); } - payload = try state.allocator.dupe(u8, snake_case_member); + payload = try allocator.dupe(u8, snake_case_member); }, else => {}, } } if (!found_name_trait) - field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = member.name }); - defer state.allocator.free(snake_case_member); + field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = member.name }); + try outputIndent(child_state, writer); const member_name = avoidReserved(snake_case_member); try writer.print("{s}: ", .{member_name}); try writeOptional(member.traits, writer, null); if (try generateTypeFor(member.target, writer, child_state, true)) - try map_fields.append(try std.fmt.allocPrint(state.allocator, "{s}", .{member_name})); + try map_fields.append(try std.fmt.allocPrint(allocator, "{s}", .{member_name})); if (!std.mem.eql(u8, "union", type_type_name)) try writeOptional(member.traits, writer, " = null"); diff --git a/lib/json/src/json.zig b/lib/json/src/json.zig index d307e9e..52b4bc1 100644 --- a/lib/json/src/json.zig +++ b/lib/json/src/json.zig @@ -14,35 +14,15 @@ const testing = std.testing; const mem = std.mem; const maxInt = std.math.maxInt; -pub fn serializeMap(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool { +pub fn serializeMap(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !void { if (@typeInfo(@TypeOf(map)) == .optional) { - if (map == null) - return false - else - return serializeMapInternal(map.?, key, options, out_stream); + if (map) |m| serializeMapInternal(m, key, options, out_stream); + } else { + serializeMapInternal(map, key, options, out_stream); } - return serializeMapInternal(map, key, options, out_stream); } -fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool { - if (map.len == 0) { - var child_options = options; - if (child_options.whitespace) |*child_ws| - child_ws.indent_level += 1; - - try out_stream.writeByte('"'); - try out_stream.writeAll(key); - _ = try out_stream.write("\":"); - if (options.whitespace) |ws| { - if (ws.separator) { - try out_stream.writeByte(' '); - } - } - try out_stream.writeByte('{'); - try out_stream.writeByte('}'); - return true; - } - // TODO: Map might be [][]struct{key, value} rather than []struct{key, value} +fn serializeMapKey(key: []const u8, options: anytype, out_stream: anytype) !void { var child_options = options; if (child_options.whitespace) |*child_ws| child_ws.indent_level += 1; @@ -55,36 +35,52 @@ fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_str try out_stream.writeByte(' '); } } +} + +pub fn serializeMapAsObject(map: anytype, options: anytype, out_stream: anytype) !void { + if (map.len == 0) { + try out_stream.writeByte('{'); + try out_stream.writeByte('}'); + } + + // TODO: Map might be [][]struct{key, value} rather than []struct{key, value} + try out_stream.writeByte('{'); if (options.whitespace) |_| try out_stream.writeByte('\n'); for (map, 0..) |tag, i| { if (tag.key == null or tag.value == null) continue; // TODO: Deal with escaping and general "json.stringify" the values... - if (child_options.whitespace) |ws| + if (options.whitespace) |ws| try ws.outputIndent(out_stream); try out_stream.writeByte('"'); - try jsonEscape(tag.key.?, child_options, out_stream); + try jsonEscape(tag.key.?, options, out_stream); _ = try out_stream.write("\":"); - if (child_options.whitespace) |ws| { + if (options.whitespace) |ws| { if (ws.separator) { try out_stream.writeByte(' '); } } try out_stream.writeByte('"'); - try jsonEscape(tag.value.?, child_options, out_stream); + try jsonEscape(tag.value.?, options, out_stream); try out_stream.writeByte('"'); if (i < map.len - 1) { try out_stream.writeByte(','); } - if (child_options.whitespace) |_| + if (options.whitespace) |_| try out_stream.writeByte('\n'); } if (options.whitespace) |ws| try ws.outputIndent(out_stream); try out_stream.writeByte('}'); - return true; } + +fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool { + var child_options = options; + try serializeMapKey(key, &child_options, out_stream); + return try serializeMapAsObject(map, child_options, out_stream); +} + // code within jsonEscape lifted from json.zig in stdlib fn jsonEscape(value: []const u8, options: anytype, out_stream: anytype) !void { var i: usize = 0; diff --git a/src/aws.zig b/src/aws.zig index de72442..7a224cc 100644 --- a/src/aws.zig +++ b/src/aws.zig @@ -1389,7 +1389,7 @@ test "custom serialization for map objects" { defer tags.deinit(); tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" }); tags.appendAssumeCapacity(.{ .key = "Baz", .value = "Qux" }); - const req = services.lambda.tag_resource.Request{ .resource = "hello", .tags = tags.items }; + const req = services.lambda.TagResourceRequest{ .resource = "hello", .tags = tags.items }; try json.stringify(req, .{ .whitespace = .{} }, buffer.writer()); try std.testing.expectEqualStrings( \\{ diff --git a/src/main.zig b/src/main.zig index 1a7bfaf..658791f 100644 --- a/src/main.zig +++ b/src/main.zig @@ -192,7 +192,7 @@ pub fn main() anyerror!void { const func = fns[0]; const arn = func.function_arn.?; // This is a bit ugly. Maybe a helper function in the library would help? - var tags = try std.ArrayList(@typeInfo(try typeForField(services.lambda.tag_resource.Request, "tags")).pointer.child).initCapacity(allocator, 1); + var tags = try std.ArrayList(aws.services.lambda.TagKeyValue).initCapacity(allocator, 1); defer tags.deinit(); tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" }); const req = services.lambda.tag_resource.Request{ .resource = arn, .tags = tags.items };