diff --git a/codegen/src/main.zig b/codegen/src/main.zig index 08c6aeb..1e66168 100644 --- a/codegen/src/main.zig +++ b/codegen/src/main.zig @@ -69,6 +69,72 @@ fn generateServicesForFilePath(allocator: *std.mem.Allocator, comptime terminato defer file.close(); return try generateServices(allocator, terminator, file, writer); } + +fn addReference(id: []const u8, map: *std.StringHashMap(u64)) !void { + const res = try map.getOrPut(id); + if (res.found_existing) { + res.value_ptr.* += 1; + } else { + res.value_ptr.* = 1; + } +} +fn countAllReferences(shape_ids: [][]const u8, shapes: std.StringHashMap(smithy.ShapeInfo), shape_references: *std.StringHashMap(u64), stack: *std.ArrayList([]const u8)) anyerror!void { + for (shape_ids) |id| { + try countReferences(shapes.get(id).?, shapes, shape_references, stack); + } +} +fn countTypeMembersReferences(type_members: []smithy.TypeMember, shapes: std.StringHashMap(smithy.ShapeInfo), shape_references: *std.StringHashMap(u64), stack: *std.ArrayList([]const u8)) anyerror!void { + for (type_members) |m| { + try countReferences(shapes.get(m.target).?, shapes, shape_references, stack); + } +} + +fn countReferences(shape: smithy.ShapeInfo, shapes: std.StringHashMap(smithy.ShapeInfo), shape_references: *std.StringHashMap(u64), stack: *std.ArrayList([]const u8)) anyerror!void { + // Add ourselves as a reference, then we will continue down the tree + try addReference(shape.id, shape_references); + // Put ourselves on the stack. If we come back to ourselves, we want to end. + for (stack.items) |i| { + if (std.mem.eql(u8, shape.id, i)) + return; + } + try stack.append(shape.id); + defer _ = stack.pop(); + // Well, this is a fun read: https://awslabs.github.io/smithy/1.0/spec/core/model.html#recursive-shape-definitions + // Looks like recursion has special rules in the spec to accommodate Java. 
+ // This is silly and we will ignore + switch (shape.shape) { + // We don't care about these primitives - they don't have children + .blob, + .boolean, + .string, + .byte, + .short, + .integer, + .long, + .float, + .double, + .bigInteger, + .bigDecimal, + .timestamp, + => {}, + .document, .member, .resource => {}, // less sure about these? + .list => |i| try countReferences(shapes.get(i.member_target).?, shapes, shape_references, stack), + .set => |i| try countReferences(shapes.get(i.member_target).?, shapes, shape_references, stack), + .map => |i| { + try countReferences(shapes.get(i.key).?, shapes, shape_references, stack); + try countReferences(shapes.get(i.value).?, shapes, shape_references, stack); + }, + .structure => |m| try countTypeMembersReferences(m.members, shapes, shape_references, stack), + .uniontype => |m| try countTypeMembersReferences(m.members, shapes, shape_references, stack), + .service => |i| try countAllReferences(i.operations, shapes, shape_references, stack), + .operation => |op| { + if (op.input) |i| try countReferences(shapes.get(i).?, shapes, shape_references, stack); + if (op.output) |i| try countReferences(shapes.get(i).?, shapes, shape_references, stack); + if (op.errors) |i| try countAllReferences(i, shapes, shape_references, stack); + }, + } +} + fn generateServices(allocator: *std.mem.Allocator, comptime _: []const u8, file: std.fs.File, writer: anytype) ![][]const u8 { const json = try file.readToEndAlloc(allocator, 1024 * 1024 * 1024); defer allocator.free(json); @@ -85,8 +151,29 @@ fn generateServices(allocator: *std.mem.Allocator, comptime _: []const u8, file: else => {}, } } + // At this point we want to generate a graph of shapes, starting + // services -> operations -> other shapes. 
This will allow us to get + // a reference count in case there are recursive data structures + var shape_references = std.StringHashMap(u64).init(allocator); + defer shape_references.deinit(); + var stack = std.ArrayList([]const u8).init(allocator); + defer stack.deinit(); + for (services.items) |service| + try countReferences(service, shapes, &shape_references, &stack); + var constant_names = std.ArrayList([]const u8).init(allocator); defer constant_names.deinit(); + var unresolved = std.ArrayList(smithy.ShapeInfo).init(allocator); + defer unresolved.deinit(); + var generated = std.StringHashMap(void).init(allocator); + defer generated.deinit(); + + var state = FileGenerationState{ + .shape_references = shape_references, + .additional_types_to_generate = &unresolved, + .additional_types_generated = &generated, + .shapes = shapes, + }; for (services.items) |service| { var sdk_id: []const u8 = undefined; var version: []const u8 = service.shape.service.version; @@ -136,10 +223,33 @@ fn generateServices(allocator: *std.mem.Allocator, comptime _: []const u8, file: // Operations for (service.shape.service.operations) |op| - try generateOperation(allocator, shapes.get(op).?, shapes, writer); + try generateOperation(allocator, shapes.get(op).?, state, writer); } + try generateAdditionalTypes(allocator, state, writer); return constant_names.toOwnedSlice(); } + +fn generateAdditionalTypes(allocator: *std.mem.Allocator, file_state: FileGenerationState, writer: anytype) !void { + // More types may be added during processing + while (file_state.additional_types_to_generate.popOrNull()) |t| { + if (file_state.additional_types_generated.getEntry(t.name) != null) continue; + // std.log.info("\t\t{s}", .{t.name}); + var type_stack = std.ArrayList(*const smithy.ShapeInfo).init(allocator); + defer type_stack.deinit(); + const state = GenerationState{ + .type_stack = &type_stack, + .file_state = file_state, + .allocator = allocator, + .indent_level = 0, + }; + const type_name = 
avoidReserved(t.name); + try writer.print("\npub const {s} = ", .{type_name}); + try file_state.additional_types_generated.putNoClobber(t.name, {}); + _ = try generateTypeFor(t.id, writer, state, true); + _ = try writer.write(";\n"); + } +} + fn constantName(allocator: *std.mem.Allocator, id: []const u8) ![]const u8 { // There are some ids that don't follow consistent rules, so we'll // look for the exceptions and, if not found, revert to the snake case @@ -161,19 +271,25 @@ fn constantName(allocator: *std.mem.Allocator, id: []const u8) ![]const u8 { return try snake.fromPascalCase(allocator, id); } +const FileGenerationState = struct { + shapes: std.StringHashMap(smithy.ShapeInfo), + shape_references: std.StringHashMap(u64), + additional_types_to_generate: *std.ArrayList(smithy.ShapeInfo), + additional_types_generated: *std.StringHashMap(void), +}; const GenerationState = struct { type_stack: *std.ArrayList(*const smithy.ShapeInfo), + file_state: FileGenerationState, // we will need some sort of "type decls needed" for recursive structures allocator: *std.mem.Allocator, indent_level: u64, - all_required: bool, }; fn outputIndent(state: GenerationState, writer: anytype) !void { const n_chars = 4 * state.indent_level; try writer.writeByteNTimes(' ', n_chars); } -fn generateOperation(allocator: *std.mem.Allocator, operation: smithy.ShapeInfo, shapes: std.StringHashMap(smithy.ShapeInfo), writer: anytype) !void { +fn generateOperation(allocator: *std.mem.Allocator, operation: smithy.ShapeInfo, file_state: FileGenerationState, writer: anytype) !void { const snake_case_name = try snake.fromPascalCase(allocator, operation.name); defer allocator.free(snake_case_name); @@ -181,9 +297,9 @@ fn generateOperation(allocator: *std.mem.Allocator, operation: smithy.ShapeInfo, defer type_stack.deinit(); const state = GenerationState{ .type_stack = &type_stack, + .file_state = file_state, .allocator = allocator, .indent_level = 1, - .all_required = false, }; var child_state = state; 
child_state.indent_level += 1; @@ -211,7 +327,7 @@ fn generateOperation(allocator: *std.mem.Allocator, operation: smithy.ShapeInfo, try outputIndent(state, writer); _ = try writer.write("Request: type = "); if (operation.shape.operation.input) |member| { - if (try generateTypeFor(member, shapes, writer, state, false)) unreachable; // we expect only structs here + if (try generateTypeFor(member, writer, state, false)) unreachable; // we expect only structs here _ = try writer.write("\n"); try generateMetadataFunction(operation_name, state, writer); } else { @@ -222,7 +338,7 @@ fn generateOperation(allocator: *std.mem.Allocator, operation: smithy.ShapeInfo, try outputIndent(state, writer); _ = try writer.write("Response: type = "); if (operation.shape.operation.output) |member| { - if (try generateTypeFor(member, shapes, writer, state, true)) unreachable; // we expect only structs here + if (try generateTypeFor(member, writer, state, true)) unreachable; // we expect only structs here } else _ = try writer.write("struct {}"); // we want to maintain consistency with other ops _ = try writer.write(",\n"); @@ -230,7 +346,7 @@ fn generateOperation(allocator: *std.mem.Allocator, operation: smithy.ShapeInfo, try outputIndent(state, writer); _ = try writer.write("ServiceError: type = error{\n"); for (errors) |err| { - const err_name = getErrorName(shapes.get(err).?.name); // need to remove "exception" + const err_name = getErrorName(file_state.shapes.get(err).?.name); // need to remove "exception" try outputIndent(child_state, writer); try writer.print("{s},\n", .{err_name}); } @@ -276,16 +392,42 @@ fn endsWith(item: []const u8, str: []const u8) bool { if (str.len < item.len) return false; return std.mem.eql(u8, item, str[str.len - item.len ..]); } -/// return type is anyerror!void as this is a recursive function, so the compiler cannot properly infer error types -fn generateTypeFor(shape_id: []const u8, shapes: std.StringHashMap(smithy.ShapeInfo), writer: anytype, state: 
GenerationState, end_structure: bool) anyerror!bool { + +fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationState) !bool { + // We want to return if we're at the top level of the stack. There are three + // reasons for this: + // 1. For operations, we have a request that includes a metadata function + // to enable aws.zig eventually to find the action based on a request. + // This could be considered a hack and maybe we should remove that + // caller convenience ability. + // 2. Given the state of zig compiler tooling, "intellisense" or whatever + // we're calling it these days, isn't real mature, so we end up looking + // at the models quite a bit. Leaving the top level alone here reduces + // the need for users to hop around too much looking at types as they + // can at least see the top level. + // 3. When we come through at the end, we want to make sure we're writing + // something or we'll have an infinite loop! + if (state.type_stack.items.len == 1) return false; var rc = false; - if (shapes.get(shape_id) == null) { - std.debug.print("Shape ID not found. This is most likely a bug. Shape ID: {s}\n", .{shape_id}); - return error.InvalidType; + if (state.file_state.shape_references.get(shape.id)) |r| { + if (r > 1 and (shape.shape == .structure or shape.shape == .uniontype)) { + rc = true; + _ = try writer.write(avoidReserved(shape.name)); // This can't possibly be this easy... 
+ if (state.file_state.additional_types_generated.getEntry(shape.name) == null) + try state.file_state.additional_types_to_generate.append(shape); + } + } + return rc; +} +/// return type is anyerror!bool (explicit, not inferred) as this is a recursive function, so the compiler cannot properly infer error types +fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState, end_structure: bool) anyerror!bool { + var rc = false; // We assume it must exist - const shape_info = shapes.get(shape_id).?; + const shape_info = state.file_state.shapes.get(shape_id) orelse { + std.debug.print("Shape ID not found. This is most likely a bug. Shape ID: {s}\n", .{shape_id}); + return error.InvalidType; + }; const shape = shape_info.shape; // Check for ourselves up the stack var self_occurences: u8 = 0; @@ -324,30 +466,34 @@ fn generateTypeFor(shape_id: []const u8, shapes: std.StringHashMap(smithy.ShapeI defer _ = state.type_stack.pop(); switch (shape) { .structure => { - try generateComplexTypeFor(shape_id, shape.structure.members, "struct", shapes, writer, state); - if (end_structure) { + if (!try reuseCommonType(shape_info, writer, state)) { + try generateComplexTypeFor(shape_id, shape.structure.members, "struct", writer, state); + if (end_structure) { + // epilog + try outputIndent(state, writer); + _ = try writer.write("}"); + } + } + }, + .uniontype => { + if (!try reuseCommonType(shape_info, writer, state)) { + try generateComplexTypeFor(shape_id, shape.uniontype.members, "union", writer, state); // epilog try outputIndent(state, writer); _ = try writer.write("}"); } }, - .uniontype => { - try generateComplexTypeFor(shape_id, shape.uniontype.members, "union", shapes, writer, state); - // epilog - try outputIndent(state, writer); - _ = try writer.write("}"); - }, .string => |s| try generateSimpleTypeFor(s, "[]const u8", writer), .integer => |s| try generateSimpleTypeFor(s, "i64", writer), .list => { _ = try writer.write("[]"); // The serializer will have to deal with the idea we 
might be an array - return try generateTypeFor(shape.list.member_target, shapes, writer, state, true); + return try generateTypeFor(shape.list.member_target, writer, state, true); }, .set => { _ = try writer.write("[]"); // The serializer will have to deal with the idea we might be an array - return try generateTypeFor(shape.set.member_target, shapes, writer, state, true); + return try generateTypeFor(shape.set.member_target, writer, state, true); }, .timestamp => |s| try generateSimpleTypeFor(s, "i64", writer), .blob => |s| try generateSimpleTypeFor(s, "[]const u8", writer), @@ -364,14 +510,14 @@ fn generateTypeFor(shape_id: []const u8, shapes: std.StringHashMap(smithy.ShapeI try writeOptional(shape.map.traits, writer, null); var sub_maps = std.ArrayList([]const u8).init(state.allocator); defer sub_maps.deinit(); - if (try generateTypeFor(shape.map.key, shapes, writer, child_state, true)) + if (try generateTypeFor(shape.map.key, writer, child_state, true)) try sub_maps.append("key"); try writeOptional(shape.map.traits, writer, " = null"); _ = try writer.write(",\n"); try outputIndent(child_state, writer); _ = try writer.write("value: "); try writeOptional(shape.map.traits, writer, null); - if (try generateTypeFor(shape.map.value, shapes, writer, child_state, true)) + if (try generateTypeFor(shape.map.value, writer, child_state, true)) try sub_maps.append("value"); try writeOptional(shape.map.traits, writer, " = null"); _ = try writer.write(",\n"); @@ -396,7 +542,7 @@ fn generateTypeFor(shape_id: []const u8, shapes: std.StringHashMap(smithy.ShapeI fn generateSimpleTypeFor(_: anytype, type_name: []const u8, writer: anytype) !void { _ = try writer.write(type_name); // This had required stuff but the problem was elsewhere. 
Better to leave as function just in case } -fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, shapes: std.StringHashMap(smithy.ShapeInfo), writer: anytype, state: GenerationState) anyerror!void { +fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, writer: anytype, state: GenerationState) anyerror!void { _ = shape_id; const Mapping = struct { snake: []const u8, json: []const u8 }; var json_field_name_mappings = try std.ArrayList(Mapping).initCapacity(state.allocator, members.len); @@ -459,7 +605,7 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty const member_name = avoidReserved(snake_case_member); try writer.print("{s}: ", .{member_name}); try writeOptional(member.traits, writer, null); - if (try generateTypeFor(member.target, shapes, writer, child_state, true)) + if (try generateTypeFor(member.target, writer, child_state, true)) try map_fields.append(try std.fmt.allocPrint(state.allocator, "{s}", .{member_name})); if (!std.mem.eql(u8, "union", type_type_name))