diff --git a/src/srf.zig b/src/srf.zig index 5e9ad75..f116a3e 100644 --- a/src/srf.zig +++ b/src/srf.zig @@ -346,7 +346,19 @@ pub const Record = struct { return null; } + fn maxFields(comptime T: type) usize { + const ti = @typeInfo(T); + if (ti != .@"union") return std.meta.fields(T).len; + comptime var max_fields = 0; + inline for (std.meta.fields(T)) |f| { + const field_count = std.meta.fields(f.type).len; + if (field_count > max_fields) max_fields = field_count; + } + return max_fields + 1; + } + fn OwnedRecord(comptime T: type) type { + // for unions, we don't know how many fields we're dealing with... return struct { fields_buf: [fields_len]Field, fields_allocated: [fields_len]bool = .{false} ** fields_len, @@ -355,7 +367,7 @@ pub const Record = struct { cached_record: ?Record = null, const Self = @This(); - const fields_len = std.meta.fields(T).len; + const fields_len = maxFields(T); pub const SourceType = T; @@ -378,8 +390,8 @@ pub const Record = struct { comptime field_name: []const u8, comptime field_type: type, comptime default_value_ptr: ?*const anyopaque, + val: field_type, ) !usize { - const val = @field(self.source_value, field_name); if (default_value_ptr) |d| { const default_val: *const field_type = @ptrCast(@alignCast(d)); if (std.meta.eql(val, default_val.*)) return inx; @@ -434,21 +446,46 @@ pub const Record = struct { } } pub fn record(self: *Self) !Record { + return self.recordInternal(SourceType, self.source_value, 0); + } + fn recordInternal(self: *Self, comptime U: type, val: U, start_inx: usize) !Record { if (self.cached_record) |r| return r; - var inx: usize = 0; - const ti = @typeInfo(SourceType); + var inx: usize = start_inx; + const ti = @typeInfo(U); switch (ti) { .@"struct" => |info| { - inline for (info.fields) |f| - inx = try self.setField(inx, f.name, f.type, f.default_value_ptr); + inline for (info.fields) |f| { + const field_val = @field(val, f.name); + inx = try self.setField(inx, f.name, f.type, f.default_value_ptr, field_val); + } }, .@"union" => |info| { - inline for (info.fields) |f| - inx = try self.setField(inx, f.name, f.type, null); + const active_tag_name = @tagName(val); + comptime var has_decl = false; + inline for (info.decls) |d| { + if (comptime std.mem.eql(u8, "srf_tag_field", d.name)) has_decl = true; + } + const key = if (has_decl) + U.srf_tag_field + else + "active_tag"; + self.fields_buf[inx] = .{ + .key = key, + .value = .{ .string = active_tag_name }, + }; + inx += 1; + switch (val) { + inline else => |payload| { + if (@typeInfo(@TypeOf(payload)) == .@"union") + @compileError("Nested unions not supported for srf serialization"); + return self.recordInternal(@TypeOf(payload), payload, inx); + }, + } }, .@"enum" => |info| { + // TODO: I do not believe this is correct inline for (info.fields) |f| - inx = try self.setField(inx, f.name, self.SourceType, null); + inx = try self.setField(inx, f.name, self.SourceType, null, val); }, .error_set => return error.ErrorSetNotSupported, else => @compileError("Expected struct, union, error set or enum type, found '" ++ @typeName(T) ++ "'"), @@ -1246,6 +1283,41 @@ test "serialize/deserialize" { ; try std.testing.expectEqualStrings(expect, compact_from); } +test "unions" { + const Foo = struct { + number: u8, + true_or_false: bool, + }; + const Bar = struct { + sentence: []const u8, + decimal: f64, + }; + const MixedData = union(enum) { + foo: Foo, + bar: Bar, + + //pub const srf_tag_field = "foobar"; + }; + + const data: []const MixedData = &.{ + .{ .foo = .{ .number = 42, .true_or_false = true } }, + .{ .bar = .{ .sentence = "foobar", .decimal = 6.9 } }, + }; + const alloc = std.testing.allocator; + var buf: [4096]u8 = undefined; + const compact_from = try std.fmt.bufPrint( + &buf, + "{f}", + .{fmtFrom(MixedData, alloc, data, .{})}, + ); + const expect = + \\#!srfv1 + \\active_tag::foo,number:num:42,true_or_false:bool:true + \\active_tag::bar,sentence::foobar,decimal:num:6.9 + \\ + ; + try std.testing.expectEqualStrings(expect, compact_from); +} test "compact format length-prefixed string as last field" { // When a length-prefixed value is the last field on the line, // rest_of_data.len == size exactly. The check on line 216 uses