refactor(codegen): improve type generation

This commit is contained in:
Simon Hartcher 2025-05-28 11:14:53 +10:00
parent bc18fcca75
commit 13a926af70

View file

@ -83,6 +83,11 @@ pub fn main() anyerror!void {
if (args.len == 0) if (args.len == 0)
_ = try generateServices(allocator, ";", std.io.getStdIn(), stdout); _ = try generateServices(allocator, ";", std.io.getStdIn(), stdout);
if (verbose) {
const output_path = try output_dir.realpathAlloc(allocator, ".");
std.debug.print("Output path: {s}\n", .{output_path});
}
} }
const OutputManifest = struct { const OutputManifest = struct {
@ -184,16 +189,6 @@ fn calculateDigests(models_dir: std.fs.Dir, output_dir: std.fs.Dir, thread_pool:
}; };
} }
fn processFile(file_name: []const u8, output_dir: std.fs.Dir, manifest: anytype) !void { fn processFile(file_name: []const u8, output_dir: std.fs.Dir, manifest: anytype) !void {
// The fixed buffer for output will be 2MB, which is twice as large as the size of the EC2
// (the largest) model. We'll then flush all this at one go at the end.
var buffer = std.mem.zeroes([1024 * 1024 * 2]u8);
var output_stream = std.io.FixedBufferStream([]u8){
.buffer = &buffer,
.pos = 0,
};
var counting_writer = std.io.countingWriter(output_stream.writer());
var writer = counting_writer.writer();
// It's probably best to create our own allocator here so we can deinit at the end and // It's probably best to create our own allocator here so we can deinit at the end and
// toss all allocations related to the services in this file // toss all allocations related to the services in this file
// I can't guarantee we're not leaking something, and at the end of the // I can't guarantee we're not leaking something, and at the end of the
@ -201,6 +196,13 @@ fn processFile(file_name: []const u8, output_dir: std.fs.Dir, manifest: anytype)
var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
defer arena.deinit(); defer arena.deinit();
const allocator = arena.allocator(); const allocator = arena.allocator();
var output = try std.ArrayListUnmanaged(u8).initCapacity(allocator, 1024 * 1024 * 2);
defer output.deinit(allocator);
var counting_writer = std.io.countingWriter(output.writer(allocator));
var writer = counting_writer.writer();
_ = try writer.write("const std = @import(\"std\");\n"); _ = try writer.write("const std = @import(\"std\");\n");
_ = try writer.write("const smithy = @import(\"smithy\");\n"); _ = try writer.write("const smithy = @import(\"smithy\");\n");
_ = try writer.write("const json = @import(\"json\");\n"); _ = try writer.write("const json = @import(\"json\");\n");
@ -235,7 +237,8 @@ fn processFile(file_name: []const u8, output_dir: std.fs.Dir, manifest: anytype)
output_file_name = new_output_file_name; output_file_name = new_output_file_name;
} }
const formatted = try zigFmt(allocator, @ptrCast(buffer[0..counting_writer.bytes_written])); const unformatted: [:0]const u8 = try output.toOwnedSliceSentinel(allocator, 0);
const formatted = try zigFmt(allocator, unformatted);
// Dump our buffer out to disk // Dump our buffer out to disk
var file = try output_dir.createFile(output_file_name, .{ .truncate = true }); var file = try output_dir.createFile(output_file_name, .{ .truncate = true });
@ -425,7 +428,7 @@ fn generateServices(allocator: std.mem.Allocator, comptime _: []const u8, file:
// Service struct // Service struct
// name of the field will be snake_case of whatever comes in from // name of the field will be snake_case of whatever comes in from
// sdk_id. Not sure this will be simple... // sdk_id. Not sure this will be simple...
const constant_name = try constantName(allocator, sdk_id); const constant_name = try constantName(allocator, sdk_id, .snake);
try constant_names.append(constant_name); try constant_names.append(constant_name);
try writer.print("const Self = @This();\n", .{}); try writer.print("const Self = @This();\n", .{});
if (version) |v| if (version) |v|
@ -483,16 +486,20 @@ fn generateAdditionalTypes(allocator: std.mem.Allocator, file_state: FileGenerat
try writer.print("\npub const {s} = ", .{type_name}); try writer.print("\npub const {s} = ", .{type_name});
try file_state.additional_types_generated.putNoClobber(t.name, {}); try file_state.additional_types_generated.putNoClobber(t.name, {});
_ = try generateTypeFor(t.id, writer, state, true); _ = try generateTypeFor(t.id, writer, state, .{
.key_case = .snake,
.end_structure = true,
});
_ = try writer.write(";\n"); _ = try writer.write(";\n");
} }
} }
fn constantName(allocator: std.mem.Allocator, id: []const u8) ![]const u8 { fn constantName(allocator: std.mem.Allocator, id: []const u8, comptime to_case: case.Case) ![]const u8 {
// There are some ids that don't follow consistent rules, so we'll // There are some ids that don't follow consistent rules, so we'll
// look for the exceptions and, if not found, revert to the snake case // look for the exceptions and, if not found, revert to the snake case
// algorithm // algorithm
if (to_case == .snake) {
// This one might be a bug in snake, but it's the only example so HPDL // This one might be a bug in snake, but it's the only example so HPDL
if (std.mem.eql(u8, id, "SESv2")) return try std.fmt.allocPrint(allocator, "ses_v2", .{}); if (std.mem.eql(u8, id, "SESv2")) return try std.fmt.allocPrint(allocator, "ses_v2", .{});
if (std.mem.eql(u8, id, "CloudFront")) return try std.fmt.allocPrint(allocator, "cloudfront", .{}); if (std.mem.eql(u8, id, "CloudFront")) return try std.fmt.allocPrint(allocator, "cloudfront", .{});
@ -506,9 +513,10 @@ fn constantName(allocator: std.mem.Allocator, id: []const u8) ![]const u8 {
if (std.mem.eql(u8, id, "DevOps Guru")) return try std.fmt.allocPrint(allocator, "devops_guru", .{}); if (std.mem.eql(u8, id, "DevOps Guru")) return try std.fmt.allocPrint(allocator, "devops_guru", .{});
if (std.mem.eql(u8, id, "FSx")) return try std.fmt.allocPrint(allocator, "fsx", .{}); if (std.mem.eql(u8, id, "FSx")) return try std.fmt.allocPrint(allocator, "fsx", .{});
if (std.mem.eql(u8, id, "ETag")) return try std.fmt.allocPrint(allocator, "e_tag", .{}); if (std.mem.eql(u8, id, "ETag")) return try std.fmt.allocPrint(allocator, "e_tag", .{});
}
// Not a special case - just snake it // Not a special case - just snake it
return try case.allocTo(allocator, .snake, id); return try case.allocTo(allocator, to_case, id);
} }
const FileGenerationState = struct { const FileGenerationState = struct {
@ -529,8 +537,40 @@ fn outputIndent(state: GenerationState, writer: anytype) !void {
const n_chars = 4 * state.indent_level; const n_chars = 4 * state.indent_level;
try writer.writeByteNTimes(' ', n_chars); try writer.writeByteNTimes(' ', n_chars);
} }
/// Kinds of per-operation struct the generator knows about.
/// `generateOperation` switches on the tag to pick the suffix appended to the
/// operation name and to decide whether the operation's `input` or `output`
/// shape is used.
const StructType = enum {
/// Emitted with the "Request" suffix; built from the operation's `input`.
request,
/// Emitted with the "Response" suffix; built from the operation's `output`.
response,
/// Emitted with the "ApiRequest" suffix; built from the operation's `input`.
apiRequest,
/// Emitted with the "ApiResponse" suffix; built from the operation's `output`.
apiResponse,
};
/// One row of the `operation_sub_types` table: a struct kind paired with the
/// casing to use for that struct's member keys.
const OperationSubTypeInfo = struct {
/// Which struct variant this entry describes.
type: StructType,
/// Case handed to `generateTypeFor` as `key_case` for this variant.
key_case: case.Case,
};
/// Struct variants generated for every operation, in emission order.
/// `generateOperation` walks this table with `inline for`, so each entry is
/// known at compile time.
const operation_sub_types = [_]OperationSubTypeInfo{
    .{ .type = .request, .key_case = .snake },
    .{ .type = .response, .key_case = .snake },
    // Pascal-cased API variants are currently disabled:
    // .{ .type = .apiRequest, .key_case = .pascal },
    // .{ .type = .apiResponse, .key_case = .pascal },
};
fn generateOperation(allocator: std.mem.Allocator, operation: smithy.ShapeInfo, file_state: FileGenerationState, writer: anytype) !void { fn generateOperation(allocator: std.mem.Allocator, operation: smithy.ShapeInfo, file_state: FileGenerationState, writer: anytype) !void {
const snake_case_name = try constantName(allocator, operation.name); const snake_case_name = try constantName(allocator, operation.name, .snake);
defer allocator.free(snake_case_name); defer allocator.free(snake_case_name);
var type_stack = std.ArrayList(*const smithy.ShapeInfo).init(allocator); var type_stack = std.ArrayList(*const smithy.ShapeInfo).init(allocator);
@ -546,33 +586,38 @@ fn generateOperation(allocator: std.mem.Allocator, operation: smithy.ShapeInfo,
// indent should start at 4 spaces here // indent should start at 4 spaces here
const operation_name = avoidReserved(snake_case_name); const operation_name = avoidReserved(snake_case_name);
// Request type inline for (operation_sub_types) |type_info| {
_ = try writer.print("pub const {s}Request = ", .{operation.name}); _ = try writer.print("pub const {s}", .{operation.name});
if (operation.shape.operation.input == null or switch (type_info.type) {
(try shapeInfoForId(operation.shape.operation.input.?, state)).shape == .unit) .request => try writer.writeAll("Request"),
{ .response => try writer.writeAll("Response"),
_ = try writer.write("struct {\n"); .apiRequest => try writer.writeAll("ApiRequest"),
try generateMetadataFunction(operation_name, state, writer); .apiResponse => try writer.writeAll("ApiResponse"),
} else if (operation.shape.operation.input) |member| {
if (try generateTypeFor(member, writer, state, false)) unreachable; // we expect only structs here
_ = try writer.write("\n");
try generateMetadataFunction(operation_name, state, writer);
} }
_ = try writer.write(";\n\n"); try writer.writeAll(" = ");
// Response type const operation_field_name = switch (type_info.type) {
_ = try writer.print("pub const {s}Response = ", .{operation.name}); .request, .apiRequest => "input",
if (operation.shape.operation.output == null or .response, .apiResponse => "output",
(try shapeInfoForId(operation.shape.operation.output.?, state)).shape == .unit) };
const operation_field = @field(operation.shape.operation, operation_field_name);
if (operation_field == null or
(try shapeInfoForId(operation_field.?, state)).shape == .unit)
{ {
std.debug.print("This happens: {s} {s}\n", .{ operation.name, operation_field.? });
_ = try writer.write("struct {\n"); _ = try writer.write("struct {\n");
try generateMetadataFunction(operation_name, state, writer); try generateMetadataFunction(operation_name, state, writer);
} else if (operation.shape.operation.output) |member| { } else if (operation_field) |member| {
if (try generateTypeFor(member, writer, state, false)) unreachable; // we expect only structs here if (try generateTypeFor(member, writer, state, .{
.end_structure = false,
.key_case = type_info.key_case,
})) unreachable; // we expect only structs here
_ = try writer.write("\n"); _ = try writer.write("\n");
try generateMetadataFunction(operation_name, state, writer); try generateMetadataFunction(operation_name, state, writer);
} }
_ = try writer.write(";\n\n"); _ = try writer.write(";\n\n");
}
try writer.print("pub const {s}: struct ", .{operation_name}); try writer.print("pub const {s}: struct ", .{operation_name});
_ = try writer.write("{\n"); _ = try writer.write("{\n");
@ -652,16 +697,16 @@ fn endsWith(item: []const u8, str: []const u8) bool {
} }
/// Builds the Zig type name used for `shape` in generated code.
///
/// The shape name is converted to PascalCase. Map shapes then lose their final
/// character and gain a "KeyValue" suffix (e.g. "Tags" -> "TagKeyValue"); all
/// other shapes are passed through `avoidReserved`.
///
/// The returned slice is always allocated with `allocator` and owned by the
/// caller. Returns an error if allocation or case conversion fails.
fn getTypeName(allocator: std.mem.Allocator, shape: smithy.ShapeInfo) ![]const u8 {
    const pascal_shape_name = try case.allocTo(allocator, .pascal, shape.name);

    switch (shape.shape) {
        // maps are named like "Tags"
        // this removes the trailing character and adds "KeyValue" suffix
        .map => {
            // The intermediate name is only needed to build the result.
            defer allocator.free(pascal_shape_name);
            // Saturating subtraction keeps an empty name from underflowing.
            const stem_len = pascal_shape_name.len -| 1;
            return try std.fmt.allocPrint(allocator, "{s}KeyValue", .{pascal_shape_name[0..stem_len]});
        },
        else => {
            const type_name = avoidReserved(pascal_shape_name);
            // Nothing was escaped: hand the fresh allocation to the caller.
            if (type_name.ptr == pascal_shape_name.ptr) return pascal_shape_name;
            // An escaped name is a static literal; copy it so the caller can
            // free the result uniformly, and release the unused original.
            defer allocator.free(pascal_shape_name);
            return try allocator.dupe(u8, type_name);
        },
    }
}
@ -707,8 +752,22 @@ fn shapeInfoForId(id: []const u8, state: GenerationState) !smithy.ShapeInfo {
}; };
} }
/// Options passed down through the recursive type generators.
const GenerateTypeOptions = struct {
    /// When true, `generateTypeFor` writes the closing epilog after a structure.
    end_structure: bool,
    /// Case applied to member names written by `generateComplexTypeFor`.
    key_case: case.Case,

    /// Returns a copy of these options with `end_structure` replaced by
    /// `value`; `key_case` is carried over unchanged.
    pub fn endStructure(self: @This(), value: bool) GenerateTypeOptions {
        var updated = self;
        updated.end_structure = value;
        return updated;
    }
};
/// return type is anyerror!void as this is a recursive function, so the compiler cannot properly infer error types /// return type is anyerror!void as this is a recursive function, so the compiler cannot properly infer error types
fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState, end_structure: bool) anyerror!bool { fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState, comptime options: GenerateTypeOptions) anyerror!bool {
const end_structure = options.end_structure;
var rc = false; var rc = false;
// We assume it must exist // We assume it must exist
@ -756,7 +815,7 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
switch (shape) { switch (shape) {
.structure => { .structure => {
if (!try reuseCommonType(shape_info, writer, state)) { if (!try reuseCommonType(shape_info, writer, state)) {
try generateComplexTypeFor(shape_id, shape.structure.members, "struct", writer, state); try generateComplexTypeFor(shape_id, shape.structure.members, "struct", writer, state, options);
if (end_structure) { if (end_structure) {
// epilog // epilog
try outputIndent(state, writer); try outputIndent(state, writer);
@ -766,7 +825,7 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
}, },
.uniontype => { .uniontype => {
if (!try reuseCommonType(shape_info, writer, state)) { if (!try reuseCommonType(shape_info, writer, state)) {
try generateComplexTypeFor(shape_id, shape.uniontype.members, "union", writer, state); try generateComplexTypeFor(shape_id, shape.uniontype.members, "union", writer, state, options);
// epilog // epilog
try outputIndent(state, writer); try outputIndent(state, writer);
_ = try writer.write("}"); _ = try writer.write("}");
@ -782,12 +841,12 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
.list => { .list => {
_ = try writer.write("[]"); _ = try writer.write("[]");
// The serializer will have to deal with the idea we might be an array // The serializer will have to deal with the idea we might be an array
return try generateTypeFor(shape.list.member_target, writer, state, true); return try generateTypeFor(shape.list.member_target, writer, state, options.endStructure(true));
}, },
.set => { .set => {
_ = try writer.write("[]"); _ = try writer.write("[]");
// The serializer will have to deal with the idea we might be an array // The serializer will have to deal with the idea we might be an array
return try generateTypeFor(shape.set.member_target, writer, state, true); return try generateTypeFor(shape.set.member_target, writer, state, options.endStructure(true));
}, },
.timestamp => |s| try generateSimpleTypeFor(s, "date.Timestamp", writer), .timestamp => |s| try generateSimpleTypeFor(s, "date.Timestamp", writer),
.blob => |s| try generateSimpleTypeFor(s, "[]const u8", writer), .blob => |s| try generateSimpleTypeFor(s, "[]const u8", writer),
@ -797,7 +856,7 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
.long => |s| try generateSimpleTypeFor(s, "i64", writer), .long => |s| try generateSimpleTypeFor(s, "i64", writer),
.map => |m| { .map => |m| {
if (!try reuseCommonType(shape_info, std.io.null_writer, state)) { if (!try reuseCommonType(shape_info, std.io.null_writer, state)) {
try generateMapTypeFor(m, writer, state); try generateMapTypeFor(m, writer, state, options);
rc = true; rc = true;
} else { } else {
try writer.writeAll("[]"); try writer.writeAll("[]");
@ -813,7 +872,7 @@ fn generateTypeFor(shape_id: []const u8, writer: anytype, state: GenerationState
return rc; return rc;
} }
fn generateMapTypeFor(map: anytype, writer: anytype, state: GenerationState) anyerror!void { fn generateMapTypeFor(map: anytype, writer: anytype, state: GenerationState, comptime options: GenerateTypeOptions) anyerror!void {
_ = try writer.write("struct {\n"); _ = try writer.write("struct {\n");
try writer.writeAll("pub const is_map_type = true;\n\n"); try writer.writeAll("pub const is_map_type = true;\n\n");
@ -824,7 +883,7 @@ fn generateMapTypeFor(map: anytype, writer: anytype, state: GenerationState) any
_ = try writer.write("key: "); _ = try writer.write("key: ");
try writeOptional(map.traits, writer, null); try writeOptional(map.traits, writer, null);
_ = try generateTypeFor(map.key, writer, child_state, true); _ = try generateTypeFor(map.key, writer, child_state, options.endStructure(true));
try writeOptional(map.traits, writer, " = null"); try writeOptional(map.traits, writer, " = null");
_ = try writer.write(",\n"); _ = try writer.write(",\n");
@ -832,7 +891,7 @@ fn generateMapTypeFor(map: anytype, writer: anytype, state: GenerationState) any
_ = try writer.write("value: "); _ = try writer.write("value: ");
try writeOptional(map.traits, writer, null); try writeOptional(map.traits, writer, null);
_ = try generateTypeFor(map.value, writer, child_state, true); _ = try generateTypeFor(map.value, writer, child_state, options.endStructure(true));
try writeOptional(map.traits, writer, " = null"); try writeOptional(map.traits, writer, " = null");
_ = try writer.write(",\n"); _ = try writer.write(",\n");
@ -844,7 +903,7 @@ fn generateSimpleTypeFor(_: anytype, type_name: []const u8, writer: anytype) !vo
} }
const Mapping = struct { snake: []const u8, original: []const u8 }; const Mapping = struct { snake: []const u8, original: []const u8 };
fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, writer: anytype, state: GenerationState) anyerror!void { fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, type_type_name: []const u8, writer: anytype, state: GenerationState, comptime options: GenerateTypeOptions) anyerror!void {
_ = shape_id; _ = shape_id;
var arena = std.heap.ArenaAllocator.init(state.allocator); var arena = std.heap.ArenaAllocator.init(state.allocator);
@ -876,7 +935,7 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty
var payload: ?[]const u8 = null; var payload: ?[]const u8 = null;
for (members) |member| { for (members) |member| {
// This is our mapping // This is our mapping
const snake_case_member = try constantName(allocator, member.name); const snake_case_member = try constantName(allocator, member.name, .snake);
// So it looks like some services have duplicate names?! Check out "httpMethod" // So it looks like some services have duplicate names?! Check out "httpMethod"
// in API Gateway. Not sure what we're supposed to do there. Checking the go // in API Gateway. Not sure what we're supposed to do there. Checking the go
// sdk, they move this particular duplicate to 'http_method' - not sure yet // sdk, they move this particular duplicate to 'http_method' - not sure yet
@ -909,10 +968,18 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty
field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = member.name }); field_name_mappings.appendAssumeCapacity(.{ .snake = try allocator.dupe(u8, snake_case_member), .original = member.name });
try outputIndent(child_state, writer); try outputIndent(child_state, writer);
const member_name = avoidReserved(snake_case_member);
const member_name = blk: {
if (options.key_case == .snake) {
break :blk avoidReserved(snake_case_member);
}
break :blk avoidReserved(try case.allocTo(allocator, options.key_case, snake_case_member));
};
try writer.print("{s}: ", .{member_name}); try writer.print("{s}: ", .{member_name});
try writeOptional(member.traits, writer, null); try writeOptional(member.traits, writer, null);
if (try generateTypeFor(member.target, writer, child_state, true)) if (try generateTypeFor(member.target, writer, child_state, options.endStructure(true)))
try map_fields.append(try std.fmt.allocPrint(allocator, "{s}", .{member_name})); try map_fields.append(try std.fmt.allocPrint(allocator, "{s}", .{member_name}));
if (!std.mem.eql(u8, "union", type_type_name)) if (!std.mem.eql(u8, "union", type_type_name))
@ -945,7 +1012,7 @@ fn generateComplexTypeFor(shape_id: []const u8, members: []smithy.TypeMember, ty
if (payload) |load| { if (payload) |load| {
try writer.writeByte('\n'); try writer.writeByte('\n');
try outputIndent(child_state, writer); try outputIndent(child_state, writer);
try writer.print("pub const http_payload: []const u8 = \"{s}\";", .{load}); try writer.print("pub const http_payload: []const u8 = \"{s}\";\n", .{load});
} }
try writer.writeByte('\n'); try writer.writeByte('\n');
@ -1024,17 +1091,17 @@ fn camelCase(allocator: std.mem.Allocator, name: []const u8) ![]const u8 {
const first_letter = name[0] + ('a' - 'A'); const first_letter = name[0] + ('a' - 'A');
return try std.fmt.allocPrint(allocator, "{c}{s}", .{ first_letter, name[1..] }); return try std.fmt.allocPrint(allocator, "{c}{s}", .{ first_letter, name[1..] });
} }
/// Returns an identifier that is safe to emit in generated Zig source.
///
/// If `name` exactly matches one of the reserved words in the table below, the
/// `@"..."` quoted form is returned as a static string. Otherwise `name` itself
/// is returned unchanged (no copy is made).
fn avoidReserved(name: []const u8) []const u8 {
    const Escape = struct { word: []const u8, quoted: []const u8 };
    const escapes = [_]Escape{
        .{ .word = "error", .quoted = "@\"error\"" },
        .{ .word = "return", .quoted = "@\"return\"" },
        .{ .word = "not", .quoted = "@\"not\"" },
        .{ .word = "and", .quoted = "@\"and\"" },
        .{ .word = "or", .quoted = "@\"or\"" },
        .{ .word = "test", .quoted = "@\"test\"" },
        .{ .word = "null", .quoted = "@\"null\"" },
        .{ .word = "export", .quoted = "@\"export\"" },
        .{ .word = "union", .quoted = "@\"union\"" },
        .{ .word = "enum", .quoted = "@\"enum\"" },
        .{ .word = "inline", .quoted = "@\"inline\"" },
    };

    for (escapes) |entry| {
        if (std.mem.eql(u8, name, entry.word)) return entry.quoted;
    }
    return name;
}