diff --git a/codegen/src/main.zig b/codegen/src/main.zig index 18709ef..1bc6998 100644 --- a/codegen/src/main.zig +++ b/codegen/src/main.zig @@ -249,12 +249,22 @@ fn addReference(id: []const u8, map: *std.StringHashMap(u64)) !void { } fn countAllReferences(shape_ids: [][]const u8, shapes: std.StringHashMap(smithy.ShapeInfo), shape_references: *std.StringHashMap(u64), stack: *std.ArrayList([]const u8)) anyerror!void { for (shape_ids) |id| { - try countReferences(shapes.get(id).?, shapes, shape_references, stack); + const shape = shapes.get(id); + if (shape == null) { + std.log.err("Error - could not find shape with id {s}", .{id}); + return error.ShapeNotFound; + } + try countReferences(shape.?, shapes, shape_references, stack); } } fn countTypeMembersReferences(type_members: []smithy.TypeMember, shapes: std.StringHashMap(smithy.ShapeInfo), shape_references: *std.StringHashMap(u64), stack: *std.ArrayList([]const u8)) anyerror!void { for (type_members) |m| { - try countReferences(shapes.get(m.target).?, shapes, shape_references, stack); + const target = shapes.get(m.target); + if (target == null) { + std.log.err("Error - could not find target {s}", .{m.target}); + return error.TargetNotFound; + } + try countReferences(target.?, shapes, shape_references, stack); } } @@ -297,8 +307,22 @@ fn countReferences(shape: smithy.ShapeInfo, shapes: std.StringHashMap(smithy.Sha .uniontype => |m| try countTypeMembersReferences(m.members, shapes, shape_references, stack), .service => |i| try countAllReferences(i.operations, shapes, shape_references, stack), .operation => |op| { - if (op.input) |i| try countReferences(shapes.get(i).?, shapes, shape_references, stack); - if (op.output) |i| try countReferences(shapes.get(i).?, shapes, shape_references, stack); + if (op.input) |i| { + const val = shapes.get(i); + if (val == null) { + std.log.err("Error processing shape with id \"{s}\". Input shape \"{s}\" was not found", .{ shape.id, i }); + return error.ShapeNotFound; + } + try countReferences(val.?, shapes, shape_references, stack); + } + if (op.output) |i| { + const val = shapes.get(i); + if (val == null) { + std.log.err("Error processing shape with id \"{s}\". Output shape \"{s}\" was not found", .{ shape.id, i }); + return error.ShapeNotFound; + } + try countReferences(val.?, shapes, shape_references, stack); + } if (op.errors) |i| try countAllReferences(i, shapes, shape_references, stack); }, } @@ -589,16 +613,21 @@ fn reuseCommonType(shape: smithy.ShapeInfo, writer: anytype, state: GenerationSt } return rc; } +fn shapeInfoForId(id: []const u8, state: GenerationState) !smithy.ShapeInfo { + return state.file_state.shapes.get(id) orelse { + std.debug.print("Shape ID not found. This is most likely a bug. Shape ID: {s}\n", .{id}); + return error.InvalidType; + }; +} + /// return type is anyerror!void as this is a recursive function, so the compiler cannot properly infer error types fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState, end_structure: bool) anyerror!bool { var rc = false; // We assume it must exist - const shape_info = state.file_state.shapes.get(shape_id) orelse { - std.debug.print("Shape ID not found. This is most likely a bug. Shape ID: {s}\n", .{shape_id}); - return error.InvalidType; - }; + const shape_info = try shapeInfoForId(shape_id, state); const shape = shape_info.shape; + // Check for ourselves up the stack var self_occurences: u8 = 0; for (state.type_stack.items) |i| {