refactor: generate types for maps

TODO: handle json stringify for generated map types
This commit is contained in:
Simon Hartcher 2025-05-14 15:29:21 +10:00
parent 522ab72296
commit 934323acf1
4 changed files with 113 additions and 96 deletions

View file

@ -467,7 +467,9 @@ fn generateAdditionalTypes(allocator: std.mem.Allocator, file_state: FileGenerat
.allocator = allocator,
.indent_level = 0,
};
const type_name = avoidReserved(t.name);
const type_name = try getTypeName(allocator, t);
defer allocator.free(type_name);
try writer.print("\npub const {s} = ", .{type_name});
try file_state.additional_types_generated.putNoClobber(t.name, {});
_ = try generateTypeFor(t.id, writer, state, true);
@ -637,6 +639,18 @@ fn endsWith(item: []const u8, str: []const u8) bool {
return std.mem.eql(u8, item, str[str.len - item.len ..]);
}
fn getTypeName(allocator: std.mem.Allocator, shape: smithy.ShapeInfo) ![]const u8 {
const type_name = avoidReserved(shape.name);
switch (shape.shape) {
.map => {
const map_type_name = avoidReserved(shape.name);
return try std.fmt.allocPrint(allocator, "{s}KeyValue", .{map_type_name[0 .. map_type_name.len - 1]});
},
else => return allocator.dupe(u8, type_name),
}
}
fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationState) !bool {
// We want to return if we're at the top level of the stack. There are three
// reasons for this:
@ -651,12 +665,21 @@ fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationSt
// can at least see the top level.
// 3. When we come through at the end, we want to make sure we're writing
// something or we'll have an infinite loop!
switch (shape.shape) {
.structure, .uniontype, .map => {},
else => return false,
}
const type_name = try getTypeName(state.allocator, shape);
defer state.allocator.free(type_name);
if (state.type_stack.items.len == 1) return false;
var rc = false;
if (state.file_state.shape_references.get(shape.id)) |r| {
if (r > 1 and (shape.shape == .structure or shape.shape == .uniontype)) {
if (r > 1) {
rc = true;
_ = try writer.write(avoidReserved(shape.name)); // This can't possibly be this easy...
_ = try writer.write(type_name); // This can't possibly be this easy...
if (state.file_state.additional_types_generated.getEntry(shape.name) == null)
try state.file_state.additional_types_to_generate.append(shape);
}
@ -755,34 +778,14 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
.double => |s| try generateSimpleTypeFor(s, "f64", writer),
.float => |s| try generateSimpleTypeFor(s, "f32", writer),
.long => |s| try generateSimpleTypeFor(s, "i64", writer),
.map => {
_ = try writer.write("[]struct {\n");
var child_state = state;
child_state.indent_level += 1;
try outputIndent(child_state, writer);
_ = try writer.write("key: ");
try writeOptional(shape.map.traits, writer, null);
var sub_maps = std.ArrayList([]const u8).init(state.allocator);
defer sub_maps.deinit();
if (try generateTypeFor(shape.map.key, writer, child_state, true))
try sub_maps.append("key");
try writeOptional(shape.map.traits, writer, " = null");
_ = try writer.write(",\n");
try outputIndent(child_state, writer);
_ = try writer.write("value: ");
try writeOptional(shape.map.traits, writer, null);
if (try generateTypeFor(shape.map.value, writer, child_state, true))
try sub_maps.append("value");
try writeOptional(shape.map.traits, writer, " = null");
_ = try writer.write(",\n");
if (sub_maps.items.len > 0) {
_ = try writer.write("\n");
try writeStringify(state, sub_maps.items, writer);
.map => |m| {
if (!try reuseCommonType(shape_info, std.io.null_writer, state)) {
try generateMapTypeFor(m, writer, state);
rc = true;
} else {
try writer.writeAll("[]");
_ = try reuseCommonType(shape_info, writer, state);
}
try outputIndent(state, writer);
_ = try writer.write("}");
rc = true;
},
else => {
std.log.err("encountered unimplemented shape type {s} for shape_id {s}. Generated code will not compile", .{ @tagName(shape), shape_id });
@ -793,41 +796,59 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
return rc;
}
fn generateMapTypeFor(map: anytype, writer: anytype, state: GenerationState) anyerror!void {
_ = try writer.write("struct {\n");
var child_state = state;
child_state.indent_level += 1;
_ = try writer.write("key: ");
try writeOptional(map.traits, writer, null);
_ = try generateTypeFor(map.key, writer, child_state, true);
try writeOptional(map.traits, writer, " = null");
_ = try writer.write(",\n");
_ = try writer.write("value: ");
try writeOptional(map.traits, writer, null);
_ = try generateTypeFor(map.value, writer, child_state, true);
try writeOptional(map.traits, writer, " = null");
_ = try writer.write(",\n");
_ = try writer.write("}");
}
fn generateSimpleTypeFor(_: anytype, type_name: []const u8, writer: anytype) !void {
_ = try writer.write(type_name); // This had required stuff but the problem was elsewhere. Better to leave as function just in case
}
const Mapping = struct { snake: []const u8, original: []const u8 };
fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, writer: anytype, state: GenerationState) anyerror!void {
_ = shape_id;
const Mapping = struct { snake: []const u8, original: []const u8 };
var field_name_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len);
defer {
for (field_name_mappings.items) |mapping|
state.allocator.free(mapping.snake);
field_name_mappings.deinit();
}
var arena = std.heap.ArenaAllocator.init(state.allocator);
defer arena.deinit();
const allocator = arena.allocator();
var field_name_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len);
defer field_name_mappings.deinit();
// There is an httpQueryParams trait as well, but nobody is using it. API GW
// pretends to, but it's an empty map
//
// Same with httpPayload
//
// httpLabel is interesting - right now we just assume anything can be used - do we need to track this?
var http_query_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len);
defer {
for (http_query_mappings.items) |mapping|
state.allocator.free(mapping.snake);
http_query_mappings.deinit();
}
var http_header_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len);
defer {
for (http_header_mappings.items) |mapping|
state.allocator.free(mapping.snake);
http_header_mappings.deinit();
}
var map_fields = std.ArrayList([]const u8).init(state.allocator);
defer {
for (map_fields.items) |f| state.allocator.free(f);
map_fields.deinit();
}
var http_query_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len);
defer http_query_mappings.deinit();
var http_header_mappings = try std.ArrayList(Mapping).initCapacity(allocator, members.len);
defer http_header_mappings.deinit();
var map_fields = std.ArrayList([]const u8).init(allocator);
defer map_fields.deinit();
// prolog. We'll rely on caller to get the spacing correct here
_ = try writer.write(type_type_name);
_ = try writer.write(" {\n");
@ -836,7 +857,7 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty
var payload: ?[]const u8 = null;
for (members) |member| {
// This is our mapping
const snake_case_member = try constantName(state.allocator, member.name);
const snake_case_member = try constantName(allocator, member.name);
// So it looks like some services have duplicate names?! Check out "httpMethod"
// in API Gateway. Not sure what we're supposed to do there. Checking the go
// sdk, they move this particular duplicate to 'http_method' - not sure yet
@ -846,34 +867,34 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty
switch (trait) {
.json_name => |n| {
found_name_trait = true;
field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n });
field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n });
},
.xml_name => |n| {
found_name_trait = true;
field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n });
field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n });
},
.http_query => |n| http_query_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = n }),
.http_header => http_header_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = trait.http_header }),
.http_query => |n| http_query_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = n }),
.http_header => http_header_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = trait.http_header }),
.http_payload => {
// Don't assert as that will be optimized for Release* builds
// We'll continue here and treat the above as a warning
if (payload) |first| {
std.log.err("Found multiple httpPayloads in violation of smithy spec! Ignoring '{s}' and using '{s}'", .{ first, snake_case_member });
}
payload = try state.allocator.dupe(u8, snake_case_member);
payload = try allocator.dupe(u8, snake_case_member);
},
else => {},
}
}
if (!found_name_trait)
field_name_mappings.appendAssumeCapacity(.{ .snake = try state.allocator.dupe(u8, snake_case_member), .original = member.name });
defer state.allocator.free(snake_case_member);
field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = member.name });
try outputIndent(child_state, writer);
const member_name = avoidReserved(snake_case_member);
try writer.print("{s}: ", .{member_name});
try writeOptional(member.traits, writer, null);
if (try generateTypeFor(member.target, writer, child_state, true))
try map_fields.append(try std.fmt.allocPrint(state.allocator, "{s}", .{member_name}));
try map_fields.append(try std.fmt.allocPrint(allocator, "{s}", .{member_name}));
if (!std.mem.eql(u8, "union", type_type_name))
try writeOptional(member.traits, writer, " = null");

View file

@ -14,35 +14,15 @@ const testing = std.testing;
const mem = std.mem;
const maxInt = std.math.maxInt;
pub fn serializeMap(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool {
pub fn serializeMap(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !void {
if (@typeInfo(@TypeOf(map)) == .optional) {
if (map == null)
return false
else
return serializeMapInternal(map.?, key, options, out_stream);
if (map) |m| serializeMapInternal(m, key, options, out_stream);
} else {
serializeMapInternal(map, key, options, out_stream);
}
return serializeMapInternal(map, key, options, out_stream);
}
fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool {
if (map.len == 0) {
var child_options = options;
if (child_options.whitespace) |*child_ws|
child_ws.indent_level += 1;
try out_stream.writeByte('"');
try out_stream.writeAll(key);
_ = try out_stream.write("\":");
if (options.whitespace) |ws| {
if (ws.separator) {
try out_stream.writeByte(' ');
}
}
try out_stream.writeByte('{');
try out_stream.writeByte('}');
return true;
}
// TODO: Map might be [][]struct{key, value} rather than []struct{key, value}
fn serializeMapKey(key: []const u8, options: anytype, out_stream: anytype) !void {
var child_options = options;
if (child_options.whitespace) |*child_ws|
child_ws.indent_level += 1;
@ -55,36 +35,52 @@ fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_str
try out_stream.writeByte(' ');
}
}
}
pub fn serializeMapAsObject(map: anytype, options: anytype, out_stream: anytype) !void {
if (map.len == 0) {
try out_stream.writeByte('{');
try out_stream.writeByte('}');
}
// TODO: Map might be [][]struct{key, value} rather than []struct{key, value}
try out_stream.writeByte('{');
if (options.whitespace) |_|
try out_stream.writeByte('\n');
for (map, 0..) |tag, i| {
if (tag.key == null or tag.value == null) continue;
// TODO: Deal with escaping and general "json.stringify" the values...
if (child_options.whitespace) |ws|
if (options.whitespace) |ws|
try ws.outputIndent(out_stream);
try out_stream.writeByte('"');
try jsonEscape(tag.key.?, child_options, out_stream);
try jsonEscape(tag.key.?, options, out_stream);
_ = try out_stream.write("\":");
if (child_options.whitespace) |ws| {
if (options.whitespace) |ws| {
if (ws.separator) {
try out_stream.writeByte(' ');
}
}
try out_stream.writeByte('"');
try jsonEscape(tag.value.?, child_options, out_stream);
try jsonEscape(tag.value.?, options, out_stream);
try out_stream.writeByte('"');
if (i < map.len - 1) {
try out_stream.writeByte(',');
}
if (child_options.whitespace) |_|
if (options.whitespace) |_|
try out_stream.writeByte('\n');
}
if (options.whitespace) |ws|
try ws.outputIndent(out_stream);
try out_stream.writeByte('}');
return true;
}
fn serializeMapInternal(map: anytype, key: []const u8, options: anytype, out_stream: anytype) !bool {
var child_options = options;
try serializeMapKey(key, &child_options, out_stream);
return try serializeMapAsObject(map, child_options, out_stream);
}
// code within jsonEscape lifted from json.zig in stdlib
fn jsonEscape(value: []const u8, options: anytype, out_stream: anytype) !void {
var i: usize = 0;

View file

@ -1389,7 +1389,7 @@ test "custom serialization for map objects" {
defer tags.deinit();
tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" });
tags.appendAssumeCapacity(.{ .key = "Baz", .value = "Qux" });
const req = services.lambda.tag_resource.Request{ .resource = "hello", .tags = tags.items };
const req = services.lambda.TagResourceRequest{ .resource = "hello", .tags = tags.items };
try json.stringify(req, .{ .whitespace = .{} }, buffer.writer());
try std.testing.expectEqualStrings(
\\{

View file

@ -192,7 +192,7 @@ pub fn main() anyerror!void {
const func = fns[0];
const arn = func.function_arn.?;
// This is a bit ugly. Maybe a helper function in the library would help?
var tags = try std.ArrayList(@typeInfo(try typeForField(services.lambda.tag_resource.Request, "tags")).pointer.child).initCapacity(allocator, 1);
var tags = try std.ArrayList(aws.services.lambda.TagKeyValue).initCapacity(allocator, 1);
defer tags.deinit();
tags.appendAssumeCapacity(.{ .key = "Foo", .value = "Bar" });
const req = services.lambda.tag_resource.Request{ .resource = arn, .tags = tags.items };