refactor: generate types for maps

TODO: handle json stringify for generated map types
This commit is contained in:
Simon Hartcher 2025-05-14 15:29:21 +10:00
parent 522ab72296
commit 934323acf1
4 changed files with 113 additions and 96 deletions

View file

@ -467,7 +467,9 @@ fn generateAdditionalTypes(allocator: std.mem.Allocator, file_state: FileGenerat
.allocator = allocator, .allocator = allocator,
.indent_level = 0, .indent_level = 0,
}; };
const type_name = avoidReserved(t.name); const type_name = try getTypeName(allocator, t);
defer allocator.free(type_name);
try writer.print("\npub const {s} = ", .{type_name}); try writer.print("\npub const {s} = ", .{type_name});
try file_state.additional_types_generated.putNoClobber(t.name, {}); try file_state.additional_types_generated.putNoClobber(t.name, {});
_ = try generateTypeFor(t.id, writer, state, true); _ = try generateTypeFor(t.id, writer, state, true);
@ -637,6 +639,18 @@ fn endsWith(item: []const u8, str: []const u8) bool {
return std.mem.eql(u8, item, str[str.len - item.len ..]); return std.mem.eql(u8, item, str[str.len - item.len ..]);
} }
fn getTypeName(allocator: std.mem.Allocator, shape: smithy.ShapeInfo) ![]const u8 {
const type_name = avoidReserved(shape.name);
switch (shape.shape) {
.map => {
const map_type_name = avoidReserved(shape.name);
return try std.fmt.allocPrint(allocator, "{s}KeyValue", .{map_type_name[0 .. map_type_name.len - 1]});
},
else => return allocator.dupe(u8, type_name),
}
}
fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationState) !bool { fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationState) !bool {
// We want to return if we're at the top level of the stack. There are three // We want to return if we're at the top level of the stack. There are three
// reasons for this: // reasons for this:
@ -651,12 +665,21 @@ fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationSt
// can at least see the top level. // can at least see the top level.
// 3. When we come through at the end, we want to make sure we're writing // 3. When we come through at the end, we want to make sure we're writing
// something or we'll have an infinite loop! // something or we'll have an infinite loop!
switch (shape.shape) {
.structure, .uniontype, .map => {},
else => return false,
}
const type_name = try getTypeName(state.allocator, shape);
defer state.allocator.free(type_name);
if (state.type_stack.items.len == 1) return false; if (state.type_stack.items.len == 1) return false;
var rc = false; var rc = false;
if (state.file_state.shape_references.get(shape.id)) |r| { if (state.file_state.shape_references.get(shape.id)) |r| {
if (r > 1 and (shape.shape == .structure or shape.shape == .uniontype)) { if (r > 1) {
rc = true; rc = true;
_ = try writer.write(avoidReserved(shape.name)); // This can't possibly be this easy... _ = try writer.write(type_name); // This can't possibly be this easy...
if (state.file_state.additional_types_generated.getEntry(shape.name) == null) if (state.file_state.additional_types_generated.getEntry(shape.name) == null)
try state.file_state.additional_types_to_generate.append(shape); try state.file_state.additional_types_to_generate.append(shape);
} }
@ -755,34 +778,14 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
.double => |s| try generateSimpleTypeFor(s, "f64", writer), .double => |s| try generateSimpleTypeFor(s, "f64", writer),
.float => |s| try generateSimpleTypeFor(s, "f32", writer), .float => |s| try generateSimpleTypeFor(s, "f32", writer),
.long => |s| try generateSimpleTypeFor(s, "i64", writer), .long => |s| try generateSimpleTypeFor(s, "i64", writer),
.map => { .map => |m| {
_ = try writer.write("[]struct {\n"); if (!try reuseCommonType(shape_info, std.io.null_writer, state)) {
var child_state = state; try generateMapTypeFor(m, writer, state);
child_state.indent_level += 1; rc = true;
try outputIndent(child_state, writer); } else {
_ = try writer.write("key: "); try writer.writeAll("[]");
try writeOptional(shape.map.traits, writer, null); _ = try reuseCommonType(shape_info, writer, state);
var sub_maps = std.ArrayList([]const u8).init(state.allocator);
defer sub_maps.deinit();
if (try generateTypeFor(shape.map.key, writer, child_state, true))
try sub_maps.append("key");
try writeOptional(shape.map.traits, writer, " = null");
_ = try writer.write(",\n");
try outputIndent(child_state, writer);
_ = try writer.write("value: ");
try writeOptional(shape.map.traits, writer, null);
if (try generateTypeFor(shape.map.value, writer, child_state, true))
try sub_maps.append("value");
try writeOptional(shape.map.traits, writer, " = null");
_ = try writer.write(",\n");
if (sub_maps.items.len > 0) {
_ = try writer.write("\n");
try writeStringify(state, sub_maps.items, writer);
} }
try outputIndent(state, writer);
_ = try writer.write("}");
rc = true;
}, },
else => { else => {
std.log.err("encountered unimplemented shape type {s} for shape_id {s}. Generated code will not compile", .{ @tagName(shape), shape_id }); std.log.err("encountered unimplemented shape type {s} for shape_id {s}. Generated code will not compile", .{ @tagName(shape), shape_id });
@ -793,41 +796,59 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
return rc; return rc;
} }
fn generateMapTypeFor(map: anytype, writer: anytype, state: GenerationState) anyerror!void {
_ = try writer.write("struct {\n");
var child_state = state;
child_state.indent_level += 1;
_ = try writer.write("key: ");
try writeOptional(map.traits, writer, null);
_ = try generateTypeFor(map.key, writer, child_state, true);
try writeOptional(map.traits, writer, " = null");
_ = try writer.write(",\n");
_ = try writer.write("value: ");
try writeOptional(map.traits, writer, null);
_ = try generateTypeFor(map.value, writer, child_state, true);
try writeOptional(map.traits, writer, " = null");
_ = try writer.write(",\n");
_ = try writer.write("}");
}
fn generateSimpleTypeFor(_: anytype, type_name: []const u8, writer: anytype) !void { fn generateSimpleTypeFor(_: anytype, type_name: []const u8, writer: anytype) !void {
_ = try writer.write(type_name); // This had required stuff but the problem was elsewhere. Better to leave as function just in case _ = try writer.write(type_name); // This had required stuff but the problem was elsewhere. Better to leave as function just in case
} }
const Mapping = struct { snake: []const u8, original: []const u8 };
fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, writer: anytype, state: GenerationState) anyerror!void { fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, writer: anytype, state: GenerationState) anyerror!void {
_ = shape_id; _ = shape_id;
const Mapping = struct { snake: []const u8, original: []const u8 };
var field_name_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len); var arena = std.heap.ArenaAllocator.init(state.allocator);
defer { defer arena.deinit();
for (field_name_mappings.items) |mapping| const allocator = arena.allocator();
state.allocator.free(mapping.snake);
field_name_mappings.deinit(); var field_name_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len);
} defer field_name_mappings.deinit();
// There is an httpQueryParams trait as well, but nobody is using it. API GW // There is an httpQueryParams trait as well, but nobody is using it. API GW
// pretends to, but it's an empty map // pretends to, but it's an empty map
// //
// Same with httpPayload // Same with httpPayload
// //
// httpLabel is interesting - right now we just assume anything can be used - do we need to track this? // httpLabel is interesting - right now we just assume anything can be used - do we need to track this?
var http_query_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len); var http_query_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len);
defer { defer http_query_mappings.deinit();
for (http_query_mappings.items) |mapping|
state.allocator.free(mapping.snake); var http_header_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len);
http_query_mappings.deinit(); defer http_header_mappings.deinit();
}
var http_header_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len); var map_fields = std.ArrayList([]const u8).init(allocator);
defer { defer map_fields.deinit();
for (http_header_mappings.items) |mapping|
state.allocator.free(mapping.snake);
http_header_mappings.deinit();
}
var map_fields = std.ArrayList([]const u8).init(state.allocator);
defer {
for (map_fields.items) |f| state.allocator.free(f);
map_fields.deinit();
}
// prolog. We'll rely on caller to get the spacing correct here // prolog. We'll rely on caller to get the spacing correct here
_ = try writer.write(type_type_name); _ = try writer.write(type_type_name);
_ = try writer.write(" {\n"); _ = try writer.write(" {\n");
@ -836,7 +857,7 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty
var payload: ?[]const u8 = null; var payload: ?[]const u8 = null;
for (members) |member| { for (members) |member| {
// This is our mapping // This is our mapping
const snake_case_member = try constantName(state.allocator, member.name); const snake_case_member = try constantName(allocator, member.name);
// So it looks like some services have duplicate names?! Check out "httpMethod" // So it looks like some services have duplicate names?! Check out "httpMethod"
// in API Gateway. Not sure what we're supposed to do there. Checking the go // in API Gateway. Not sure what we're supposed to do there. Checking the go
// sdk, they move this particular duplicate to 'http_method' - not sure yet // sdk, they move this particular duplicate to 'http_method' - not sure yet
@ -846,34 +867,34 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty
switch (trait) { switch (trait) {
.json_name => |n| { .json_name => |n| {
found_name_trait = true; found_name_trait = true;
field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n }); field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n });
}, },
.xml_name => |n| { .xml_name => |n| {
found_name_trait = true; found_name_trait = true;
field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n }); field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n });
}, },
.http_query => |n| http_query_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n }), .http_query => |n| http_query_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n }),
.http_header => http_header_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = trait.http_header }), .http_header => http_header_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = trait.http_header }),
.http_payload => { .http_payload => {
// Don't assert as that will be optimized for Release* builds // Don't assert as that will be optimized for Release* builds
// We'll continue here and treat the above as a warning // We'll continue here and treat the above as a warning
if (payload) |first| { if (payload) |first| {
std.log.err("Found multiple httpPayloads in violation of smithy spec! Ignoring '{s}' and using '{s}'", .{ first, snake_case_member }); std.log.err("Found multiple httpPayloads in violation of smithy spec! Ignoring '{s}' and using '{s}'", .{ first, snake_case_member });
} }
payload = try state.allocator.dupe(u8, snake_case_member); payload = try allocator.dupe(u8, snake_case_member);
}, },
else => {}, else => {},
} }
} }
if (!found_name_trait) if (!found_name_trait)
field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = member.name }); field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = member.name });
defer state.allocator.free(snake_case_member);
try outputIndent(child_state, writer); try outputIndent(child_state, writer);
const member_name = avoidReserved(snake_case_member); const member_name = avoidReserved(snake_case_member);
try writer.print("{s}: ", .{member_name}); try writer.print("{s}: ", .{member_name});
try writeOptional(member.traits, writer, null); try writeOptional(member.traits, writer, null);
if (try generateTypeFor(member.target, writer, child_state, true)) if (try generateTypeFor(member.target, writer, child_state, true))
try map_fields.append(try std.fmt.allocPrint(state.allocator, "{s}", .{member_name})); try map_fields.append(try std.fmt.allocPrint(allocator, "{s}", .{member_name}));
if (!std.mem.eql(u8, "union", type_type_name)) if (!std.mem.eql(u8, "union", type_type_name))
try writeOptional(member.traits, writer, " = null"); try writeOptional(member.traits, writer, " = null");

View file

@ -14,35 +14,15 @@ const testing = std.testing;
const mem = std.mem; const mem = std.mem;
const maxInt = std.math.maxInt; const maxInt = std.math.maxInt;
pub fn serializeMap(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool { pub fn serializeMap(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !void {
if (@typeInfo(@TypeOf(map)) == .optional) { if (@typeInfo(@TypeOf(map)) == .optional) {
if (map == null) if (map) |m| serializeMapInternal(m, key, options, out_stream);
return false } else {
else serializeMapInternal(map, key, options, out_stream);
return serializeMapInternal(map.?, key, options, out_stream);
} }
return serializeMapInternal(map, key, options, out_stream);
} }
fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool { fn serializeMapKey(key: []const u8, options: anytype, out_stream: anytype) !void {
if (map.len == 0) {
var child_options = options;
if (child_options.whitespace) |*child_ws|
child_ws.indent_level += 1;
try out_stream.writeByte('"');
try out_stream.writeAll(key);
_ = try out_stream.write("\":");
if (options.whitespace) |ws| {
if (ws.separator) {
try out_stream.writeByte(' ');
}
}
try out_stream.writeByte('{');
try out_stream.writeByte('}');
return true;
}
// TODO: Map might be [][]struct{key, value} rather than []struct{key, value}
var child_options = options; var child_options = options;
if (child_options.whitespace) |*child_ws| if (child_options.whitespace) |*child_ws|
child_ws.indent_level += 1; child_ws.indent_level += 1;
@ -55,36 +35,52 @@ fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_str
try out_stream.writeByte(' '); try out_stream.writeByte(' ');
} }
} }
}
pub fn serializeMapAsObject(map: anytype, options: anytype, out_stream: anytype) !void {
if (map.len == 0) {
try out_stream.writeByte('{');
try out_stream.writeByte('}');
}
// TODO: Map might be [][]struct{key, value} rather than []struct{key, value}
try out_stream.writeByte('{'); try out_stream.writeByte('{');
if (options.whitespace) |_| if (options.whitespace) |_|
try out_stream.writeByte('\n'); try out_stream.writeByte('\n');
for (map, 0..) |tag, i| { for (map, 0..) |tag, i| {
if (tag.key == null or tag.value == null) continue; if (tag.key == null or tag.value == null) continue;
// TODO: Deal with escaping and general "json.stringify" the values... // TODO: Deal with escaping and general "json.stringify" the values...
if (child_options.whitespace) |ws| if (options.whitespace) |ws|
try ws.outputIndent(out_stream); try ws.outputIndent(out_stream);
try out_stream.writeByte('"'); try out_stream.writeByte('"');
try jsonEscape(tag.key.?, child_options, out_stream); try jsonEscape(tag.key.?, options, out_stream);
_ = try out_stream.write("\":"); _ = try out_stream.write("\":");
if (child_options.whitespace) |ws| { if (options.whitespace) |ws| {
if (ws.separator) { if (ws.separator) {
try out_stream.writeByte(' '); try out_stream.writeByte(' ');
} }
} }
try out_stream.writeByte('"'); try out_stream.writeByte('"');
try jsonEscape(tag.value.?, child_options, out_stream); try jsonEscape(tag.value.?, options, out_stream);
try out_stream.writeByte('"'); try out_stream.writeByte('"');
if (i < map.len - 1) { if (i < map.len - 1) {
try out_stream.writeByte(','); try out_stream.writeByte(',');
} }
if (child_options.whitespace) |_| if (options.whitespace) |_|
try out_stream.writeByte('\n'); try out_stream.writeByte('\n');
} }
if (options.whitespace) |ws| if (options.whitespace) |ws|
try ws.outputIndent(out_stream); try ws.outputIndent(out_stream);
try out_stream.writeByte('}'); try out_stream.writeByte('}');
return true;
} }
fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool {
var child_options = options;
try serializeMapKey(key, &child_options, out_stream);
return try serializeMapAsObject(map, child_options, out_stream);
}
// code within jsonEscape lifted from json.zig in stdlib // code within jsonEscape lifted from json.zig in stdlib
fn jsonEscape(value: []const u8, options: anytype, out_stream: anytype) !void { fn jsonEscape(value: []const u8, options: anytype, out_stream: anytype) !void {
var i: usize = 0; var i: usize = 0;

View file

@ -1389,7 +1389,7 @@ test "custom serialization for map objects" {
defer tags.deinit(); defer tags.deinit();
tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" }); tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" });
tags.appendAssumeCapacity(.{ .key = "Baz", .value = "Qux" }); tags.appendAssumeCapacity(.{ .key = "Baz", .value = "Qux" });
const req = services.lambda.tag_resource.Request{ .resource = "hello", .tags = tags.items }; const req = services.lambda.TagResourceRequest{ .resource = "hello", .tags = tags.items };
try json.stringify(req, .{ .whitespace = .{} }, buffer.writer()); try json.stringify(req, .{ .whitespace = .{} }, buffer.writer());
try std.testing.expectEqualStrings( try std.testing.expectEqualStrings(
\\{ \\{

View file

@ -192,7 +192,7 @@ pub fn main() anyerror!void {
const func = fns[0]; const func = fns[0];
const arn = func.function_arn.?; const arn = func.function_arn.?;
// This is a bit ugly. Maybe a helper function in the library would help? // This is a bit ugly. Maybe a helper function in the library would help?
var tags = try std.ArrayList(@typeInfo(try typeForField(services.lambda.tag_resource.Request, "tags")).pointer.child).initCapacity(allocator, 1); var tags = try std.ArrayList(aws.services.lambda.TagKeyValue).initCapacity(allocator, 1);
defer tags.deinit(); defer tags.deinit();
tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" }); tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" });
const req = services.lambda.tag_resource.Request{ .resource = arn, .tags = tags.items }; const req = services.lambda.tag_resource.Request{ .resource = arn, .tags = tags.items };