some codegen changes needed/investigation into http_query (see note)

This fixes main compile issues. Problems that remain:

* json is borked for objects with key/value keys that need to be coerced
* currently all request structures need an http_query declaration, and this may be a hard requirement
* need to revisit all the places where we are reading http response bodies
* there are 35 compile errors that do not have reference traces and will take a while to track down
This commit is contained in:
Emil Lerch 2023-08-05 12:41:04 -07:00
parent e3ec2b8c2e
commit 5ee53a493d
Signed by: lobo
GPG Key ID: A7B62D657EF764F8
12 changed files with 278 additions and 256 deletions

1
.gitignore vendored
View File

@ -10,3 +10,4 @@ smithy/zig-out/
libs/ libs/
src/git_version.zig src/git_version.zig
zig-out zig-out
core

View File

@ -209,7 +209,7 @@ fn generateServices(allocator: std.mem.Allocator, comptime _: []const u8, file:
try writer.print("pub const sigv4_name: []const u8 = \"{s}\";\n", .{sigv4_name}); try writer.print("pub const sigv4_name: []const u8 = \"{s}\";\n", .{sigv4_name});
try writer.print("pub const name: []const u8 = \"{s}\";\n", .{name}); try writer.print("pub const name: []const u8 = \"{s}\";\n", .{name});
// TODO: This really should just be ".whatevs". We're fully qualifying here, which isn't typical // TODO: This really should just be ".whatevs". We're fully qualifying here, which isn't typical
try writer.print("pub const aws_protocol: smithy.AwsProtocol = smithy.{};\n\n", .{aws_protocol}); try writer.print("pub const aws_protocol: smithy.AwsProtocol = {};\n\n", .{aws_protocol});
_ = try writer.write("pub const service_metadata: struct {\n"); _ = try writer.write("pub const service_metadata: struct {\n");
try writer.print(" version: []const u8 = \"{s}\",\n", .{version}); try writer.print(" version: []const u8 = \"{s}\",\n", .{version});
try writer.print(" sdk_id: []const u8 = \"{s}\",\n", .{sdk_id}); try writer.print(" sdk_id: []const u8 = \"{s}\",\n", .{sdk_id});
@ -218,7 +218,7 @@ fn generateServices(allocator: std.mem.Allocator, comptime _: []const u8, file:
try writer.print(" sigv4_name: []const u8 = \"{s}\",\n", .{sigv4_name}); try writer.print(" sigv4_name: []const u8 = \"{s}\",\n", .{sigv4_name});
try writer.print(" name: []const u8 = \"{s}\",\n", .{name}); try writer.print(" name: []const u8 = \"{s}\",\n", .{name});
// TODO: This really should just be ".whatevs". We're fully qualifying here, which isn't typical // TODO: This really should just be ".whatevs". We're fully qualifying here, which isn't typical
try writer.print(" aws_protocol: smithy.AwsProtocol = smithy.{},\n", .{aws_protocol}); try writer.print(" aws_protocol: smithy.AwsProtocol = {},\n", .{aws_protocol});
_ = try writer.write("} = .{};\n"); _ = try writer.write("} = .{};\n");
// Operations // Operations

View File

@ -27,9 +27,7 @@ pub const services = servicemodel.services;
/// This will give you a constant with service data for sts, ec2, s3 and ddb only /// This will give you a constant with service data for sts, ec2, s3 and ddb only
pub const Services = servicemodel.Services; pub const Services = servicemodel.Services;
pub const ClientOptions = struct { pub const ClientOptions = struct {};
trust_pem: ?[]const u8 = awshttp.default_root_ca,
};
pub const Client = struct { pub const Client = struct {
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
aws_http: awshttp.AwsHttp, aws_http: awshttp.AwsHttp,
@ -37,9 +35,10 @@ pub const Client = struct {
const Self = @This(); const Self = @This();
pub fn init(allocator: std.mem.Allocator, options: ClientOptions) !Self { pub fn init(allocator: std.mem.Allocator, options: ClientOptions) !Self {
_ = options;
return Self{ return Self{
.allocator = allocator, .allocator = allocator,
.aws_http = try awshttp.AwsHttp.init(allocator, options.trust_pem), .aws_http = try awshttp.AwsHttp.init(allocator),
}; };
} }
pub fn deinit(self: *Client) void { pub fn deinit(self: *Client) void {
@ -208,7 +207,7 @@ pub fn Request(comptime request_action: anytype) type {
.query = "", .query = "",
.body = buffer.items, .body = buffer.items,
.content_type = content_type, .content_type = content_type,
.headers = &[_]awshttp.Header{.{ .name = "X-Amz-Target", .value = target }}, .headers = @constCast(&[_]awshttp.Header{.{ .name = "X-Amz-Target", .value = target }}),
}, options); }, options);
} }
@ -221,9 +220,8 @@ pub fn Request(comptime request_action: anytype) type {
var buffer = std.ArrayList(u8).init(options.client.allocator); var buffer = std.ArrayList(u8).init(options.client.allocator);
defer buffer.deinit(); defer buffer.deinit();
const writer = buffer.writer(); const writer = buffer.writer();
try url.encode(request, writer, .{ try url.encode(options.client.allocator, request, writer, .{
.field_name_transformer = &queryFieldTransformer, .field_name_transformer = &queryFieldTransformer,
.allocator = options.client.allocator,
}); });
const continuation = if (buffer.items.len > 0) "&" else ""; const continuation = if (buffer.items.len > 0) "&" else "";
@ -556,7 +554,7 @@ pub fn Request(comptime request_action: anytype) type {
// scenario, then parse as appropriate later // scenario, then parse as appropriate later
const SRawResponse = if (Self.service_meta.aws_protocol != .query and const SRawResponse = if (Self.service_meta.aws_protocol != .query and
std.meta.fields(action.Response).len == 1) std.meta.fields(action.Response).len == 1)
std.meta.fields(action.Response)[0].field_type std.meta.fields(action.Response)[0].type
else else
NullType; NullType;
@ -635,7 +633,7 @@ pub fn Request(comptime request_action: anytype) type {
}; };
return ParsedJsonData(response_types.NormalResponse){ return ParsedJsonData(response_types.NormalResponse){
.raw_response_parsed = raw_response_parsed, .raw_response_parsed = raw_response_parsed,
.parsed_response_ptr = parsed_response_ptr, .parsed_response_ptr = @constCast(parsed_response_ptr), //TODO: why doesn't changing const->var above fix this?
.allocator = options.client.allocator, .allocator = options.client.allocator,
}; };
} }
@ -792,39 +790,39 @@ fn ServerResponse(comptime action: anytype) type {
const Result = @Type(.{ const Result = @Type(.{
.Struct = .{ .Struct = .{
.layout = .Auto, .layout = .Auto,
.fields = &[_]std.builtin.TypeInfo.StructField{ .fields = &[_]std.builtin.Type.StructField{
.{ .{
.name = action.action_name ++ "Result", .name = action.action_name ++ "Result",
.field_type = T, .type = T,
.default_value = null, .default_value = null,
.is_comptime = false, .is_comptime = false,
.alignment = 0, .alignment = 0,
}, },
.{ .{
.name = "ResponseMetadata", .name = "ResponseMetadata",
.field_type = ResponseMetadata, .type = ResponseMetadata,
.default_value = null, .default_value = null,
.is_comptime = false, .is_comptime = false,
.alignment = 0, .alignment = 0,
}, },
}, },
.decls = &[_]std.builtin.TypeInfo.Declaration{}, .decls = &[_]std.builtin.Type.Declaration{},
.is_tuple = false, .is_tuple = false,
}, },
}); });
return @Type(.{ return @Type(.{
.Struct = .{ .Struct = .{
.layout = .Auto, .layout = .Auto,
.fields = &[_]std.builtin.TypeInfo.StructField{ .fields = &[_]std.builtin.Type.StructField{
.{ .{
.name = action.action_name ++ "Response", .name = action.action_name ++ "Response",
.field_type = Result, .type = Result,
.default_value = null, .default_value = null,
.is_comptime = false, .is_comptime = false,
.alignment = 0, .alignment = 0,
}, },
}, },
.decls = &[_]std.builtin.TypeInfo.Declaration{}, .decls = &[_]std.builtin.Type.Declaration{},
.is_tuple = false, .is_tuple = false,
}, },
}); });
@ -885,8 +883,9 @@ fn FullResponse(comptime action: anytype) type {
} }
}; };
} }
fn queryFieldTransformer(field_name: []const u8, encoding_options: url.EncodingOptions) anyerror![]const u8 { fn queryFieldTransformer(allocator: std.mem.Allocator, field_name: []const u8, options: url.EncodingOptions) anyerror![]const u8 {
return try case.snakeToPascal(encoding_options.allocator.?, field_name); _ = options;
return try case.snakeToPascal(allocator, field_name);
} }
fn buildPath( fn buildPath(
@ -984,26 +983,17 @@ fn buildQuery(allocator: std.mem.Allocator, request: anytype) ![]const u8 {
const writer = buffer.writer(); const writer = buffer.writer();
defer buffer.deinit(); defer buffer.deinit();
var prefix = "?"; var prefix = "?";
const Req = @TypeOf(request); // TODO: This was a pain before, and it's a pain now. Clearly our codegen
if (declaration(Req, "http_query") == null) // needs to emit a declaration 100% of the time
return buffer.toOwnedSlice(); const query_arguments = @TypeOf(request).http_query;
const query_arguments = Req.http_query;
inline for (@typeInfo(@TypeOf(query_arguments)).Struct.fields) |arg| { inline for (@typeInfo(@TypeOf(query_arguments)).Struct.fields) |arg| {
const val = @field(request, arg.name); const val = @field(request, arg.name);
if (try addQueryArg(arg.field_type, prefix, @field(query_arguments, arg.name), val, writer)) if (try addQueryArg(arg.type, prefix, @field(query_arguments, arg.name), val, writer))
prefix = "&"; prefix = "&";
} }
return buffer.toOwnedSlice(); return buffer.toOwnedSlice();
} }
fn declaration(comptime T: type, name: []const u8) ?std.builtin.TypeInfo.Declaration {
for (std.meta.declarations(T)) |decl| {
if (std.mem.eql(u8, name, decl.name))
return decl;
}
return null;
}
fn addQueryArg(comptime ValueType: type, prefix: []const u8, key: []const u8, value: anytype, writer: anytype) !bool { fn addQueryArg(comptime ValueType: type, prefix: []const u8, key: []const u8, value: anytype, writer: anytype) !bool {
switch (@typeInfo(@TypeOf(value))) { switch (@typeInfo(@TypeOf(value))) {
.Optional => { .Optional => {
@ -1044,7 +1034,9 @@ fn addBasicQueryArg(prefix: []const u8, key: []const u8, value: anytype, writer:
// TODO: url escaping // TODO: url escaping
try uriEncode(key, writer, true); try uriEncode(key, writer, true);
_ = try writer.write("="); _ = try writer.write("=");
try json.stringify(value, .{}, ignoringWriter(uriEncodingWriter(writer).writer(), '"').writer()); var encoding_writer = uriEncodingWriter(writer);
var ignoring_writer = ignoringWriter(encoding_writer.writer(), '"');
try json.stringify(value, .{}, ignoring_writer.writer());
return true; return true;
} }
pub fn uriEncodingWriter(child_stream: anytype) UriEncodingWriter(@TypeOf(child_stream)) { pub fn uriEncodingWriter(child_stream: anytype) UriEncodingWriter(@TypeOf(child_stream)) {
@ -1135,7 +1127,7 @@ fn typeForField(comptime T: type, field_name: []const u8) !type {
.Struct => { .Struct => {
inline for (ti.Struct.fields) |field| { inline for (ti.Struct.fields) |field| {
if (std.mem.eql(u8, field.name, field_name)) if (std.mem.eql(u8, field.name, field_name))
return field.field_type; return field.type;
} }
}, },
else => return error.TypeIsNotAStruct, // should not hit this else => return error.TypeIsNotAStruct, // should not hit this

View File

@ -7,7 +7,6 @@
const std = @import("std"); const std = @import("std");
const builtin = @import("builtin"); const builtin = @import("builtin");
const auth = @import("aws_authentication.zig"); const auth = @import("aws_authentication.zig");
const zfetch = @import("zfetch");
const log = std.log.scoped(.aws_credentials); const log = std.log.scoped(.aws_credentials);
@ -114,28 +113,32 @@ fn getContainerCredentials(allocator: std.mem.Allocator) !?auth.Credentials {
// from s3 and run // from s3 and run
const container_relative_uri = (try getEnvironmentVariable(allocator, "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")) orelse return null; const container_relative_uri = (try getEnvironmentVariable(allocator, "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")) orelse return null;
defer allocator.free(container_relative_uri); defer allocator.free(container_relative_uri);
try zfetch.init();
defer zfetch.deinit();
const container_uri = try std.fmt.allocPrint(allocator, "http://169.254.170.2{s}", .{container_relative_uri}); const container_uri = try std.fmt.allocPrint(allocator, "http://169.254.170.2{s}", .{container_relative_uri});
defer allocator.free(container_uri); defer allocator.free(container_uri);
var req = try zfetch.Request.init(allocator, container_uri, null); var empty_headers = std.http.Headers.init(allocator);
defer empty_headers.deinit();
var cl = std.http.Client{ .allocator = allocator };
defer cl.deinit(); // I don't belive connection pooling would help much here as it's non-ssl and local
var req = try cl.request(.GET, try std.Uri.parse(container_uri), empty_headers, .{});
defer req.deinit(); defer req.deinit();
try req.do(.GET, null, null); try req.start();
if (req.status.code != 200 and req.status.code != 404) { try req.wait();
log.warn("Bad status code received from container credentials endpoint: {}", .{req.status.code}); if (req.response.status != .ok and req.response.status != .not_found) {
log.warn("Bad status code received from container credentials endpoint: {}", .{@intFromEnum(req.response.status)});
return null; return null;
} }
if (req.status.code == 404) return null; if (req.response.status == .not_found) return null;
const reader = req.reader(); if (req.response.content_length == null or req.response.content_length.? == 0) return null;
var buf: [2048]u8 = undefined;
const read = try reader.read(&buf); var resp_payload = try std.ArrayList(u8).initCapacity(allocator, req.response.content_length.?);
if (read == 2048) { defer resp_payload.deinit();
log.warn("Unexpected long response from container credentials endpoint: {s}", .{buf}); try resp_payload.resize(req.response.content_length.?);
return null; var response_data = try resp_payload.toOwnedSlice();
} defer allocator.free(response_data);
log.debug("Read {d} bytes from container credentials endpoint", .{read}); _ = try req.readAll(response_data);
if (read == 0) return null; log.debug("Read {d} bytes from container credentials endpoint", .{response_data.len});
if (response_data.len == 0) return null;
const CredsResponse = struct { const CredsResponse = struct {
AccessKeyId: []const u8, AccessKeyId: []const u8,
@ -145,9 +148,8 @@ fn getContainerCredentials(allocator: std.mem.Allocator) !?auth.Credentials {
Token: []const u8, Token: []const u8,
}; };
const creds_response = blk: { const creds_response = blk: {
var stream = std.json.TokenStream.init(buf[0..read]); const res = std.json.parseFromSlice(CredsResponse, allocator, response_data, .{}) catch |e| {
const res = std.json.parse(CredsResponse, &stream, .{ .allocator = allocator }) catch |e| { log.err("Unexpected Json response from container credentials endpoint: {s}", .{response_data});
log.err("Unexpected Json response from container credentials endpoint: {s}", .{buf});
log.err("Error parsing json: {}", .{e}); log.err("Error parsing json: {}", .{e});
if (@errorReturnTrace()) |trace| { if (@errorReturnTrace()) |trace| {
std.debug.dumpStackTrace(trace.*); std.debug.dumpStackTrace(trace.*);
@ -157,83 +159,92 @@ fn getContainerCredentials(allocator: std.mem.Allocator) !?auth.Credentials {
}; };
break :blk res; break :blk res;
}; };
defer std.json.parseFree(CredsResponse, creds_response, .{ .allocator = allocator }); defer creds_response.deinit();
return auth.Credentials.init( return auth.Credentials.init(
allocator, allocator,
try allocator.dupe(u8, creds_response.AccessKeyId), try allocator.dupe(u8, creds_response.value.AccessKeyId),
try allocator.dupe(u8, creds_response.SecretAccessKey), try allocator.dupe(u8, creds_response.value.SecretAccessKey),
try allocator.dupe(u8, creds_response.Token), try allocator.dupe(u8, creds_response.value.Token),
); );
} }
fn getImdsv2Credentials(allocator: std.mem.Allocator) !?auth.Credentials { fn getImdsv2Credentials(allocator: std.mem.Allocator) !?auth.Credentials {
try zfetch.init(); var token: ?[]u8 = null;
defer zfetch.deinit(); defer if (token) |t| allocator.free(t);
var cl = std.http.Client{ .allocator = allocator };
var token: [1024]u8 = undefined; defer cl.deinit(); // I don't belive connection pooling would help much here as it's non-ssl and local
var len: usize = undefined;
// Get token // Get token
{ {
var headers = zfetch.Headers.init(allocator); var headers = std.http.Headers.init(allocator);
defer headers.deinit(); defer headers.deinit();
try headers.append("X-aws-ec2-metadata-token-ttl-seconds", "21600");
try headers.appendValue("X-aws-ec2-metadata-token-ttl-seconds", "21600"); var req = try cl.request(.PUT, try std.Uri.parse("http://169.254.169.254/latest/api/token"), headers, .{});
var req = try zfetch.Request.init(allocator, "http://169.254.169.254/latest/api/token", null);
defer req.deinit(); defer req.deinit();
try req.do(.PUT, headers, ""); try req.start();
if (req.status.code != 200) { try req.wait();
log.warn("Bad status code received from IMDS v2: {}", .{req.status.code}); if (req.response.status != .ok) {
log.warn("Bad status code received from IMDS v2: {}", .{@intFromEnum(req.response.status)});
return null; return null;
} }
const reader = req.reader(); if (req.response.content_length == null or req.response.content_length == 0) {
const read = try reader.read(&token); log.warn("Unexpected zero response from IMDS v2", .{});
if (read == 0 or read == 1024) {
log.warn("Unexpected zero or long response from IMDS v2: {s}", .{token});
return null; return null;
} }
len = read;
var resp_payload = try std.ArrayList(u8).initCapacity(allocator, req.response.content_length.?);
defer resp_payload.deinit();
try resp_payload.resize(req.response.content_length.?);
token = try resp_payload.toOwnedSlice();
errdefer allocator.free(token);
_ = try req.readAll(token.?);
} }
log.debug("Got token from IMDSv2", .{}); std.debug.assert(token != null);
const role_name = try getImdsRoleName(allocator, token[0..len]); log.debug("Got token from IMDSv2: {s}", .{token.?});
const role_name = try getImdsRoleName(allocator, &cl, token.?);
if (role_name == null) { if (role_name == null) {
log.info("No role is associated with this instance", .{}); log.info("No role is associated with this instance", .{});
return null; return null;
} }
defer allocator.free(role_name.?); defer allocator.free(role_name.?);
log.debug("Got role name '{s}'", .{role_name}); log.debug("Got role name '{s}'", .{role_name});
return getImdsCredentials(allocator, role_name.?, token[0..len]); return getImdsCredentials(allocator, &cl, role_name.?, token.?);
} }
fn getImdsRoleName(allocator: std.mem.Allocator, imds_token: []u8) !?[]const u8 { fn getImdsRoleName(allocator: std.mem.Allocator, client: *std.http.Client, imds_token: []u8) !?[]const u8 {
// { // {
// "Code" : "Success", // "Code" : "Success",
// "LastUpdated" : "2022-02-09T05:42:09Z", // "LastUpdated" : "2022-02-09T05:42:09Z",
// "InstanceProfileArn" : "arn:aws:iam::550620852718:instance-profile/ec2-dev", // "InstanceProfileArn" : "arn:aws:iam::550620852718:instance-profile/ec2-dev",
// "InstanceProfileId" : "AIPAYAM4POHXCFNKZ7HU2" // "InstanceProfileId" : "AIPAYAM4POHXCFNKZ7HU2"
// } // }
var buf: [255]u8 = undefined; var headers = std.http.Headers.init(allocator);
var headers = zfetch.Headers.init(allocator);
defer headers.deinit(); defer headers.deinit();
try headers.appendValue("X-aws-ec2-metadata-token", imds_token); try headers.append("X-aws-ec2-metadata-token", imds_token);
var req = try zfetch.Request.init(allocator, "http://169.254.169.254/latest/meta-data/iam/info", null); var req = try client.request(.GET, try std.Uri.parse("http://169.254.169.254/latest/meta-data/iam/info"), headers, .{});
defer req.deinit(); defer req.deinit();
try req.do(.GET, headers, null); try req.start();
try req.wait();
if (req.status.code != 200 and req.status.code != 404) { if (req.response.status != .ok and req.response.status != .not_found) {
log.warn("Bad status code received from IMDS iam endpoint: {}", .{req.status.code}); log.warn("Bad status code received from IMDS iam endpoint: {}", .{@intFromEnum(req.response.status)});
return null; return null;
} }
if (req.status.code == 404) return null; if (req.response.status == .not_found) return null;
const reader = req.reader(); if (req.response.content_length == null or req.response.content_length.? == 0) {
const read = try reader.read(&buf); log.warn("Unexpected empty response from IMDS endpoint post token", .{});
if (read == 255) {
log.warn("Unexpected zero or long response from IMDS endpoint post token: {s}", .{buf});
return null; return null;
} }
if (read == 0) return null; // TODO: This is all stupid. We can just allocate a freaking array and be done
var resp_payload = try std.ArrayList(u8).initCapacity(allocator, req.response.content_length.?);
defer resp_payload.deinit();
try resp_payload.resize(req.response.content_length.?);
// TODO: This feels safer, but can we avoid this?
const resp = try resp_payload.toOwnedSlice();
defer allocator.free(resp);
_ = try req.readAll(resp);
const ImdsResponse = struct { const ImdsResponse = struct {
Code: []const u8, Code: []const u8,
@ -241,22 +252,17 @@ fn getImdsRoleName(allocator: std.mem.Allocator, imds_token: []u8) !?[]const u8
InstanceProfileArn: []const u8, InstanceProfileArn: []const u8,
InstanceProfileId: []const u8, InstanceProfileId: []const u8,
}; };
const imds_response = blk: { const imds_response = std.json.parseFromSlice(ImdsResponse, allocator, resp, .{}) catch |e| {
var stream = std.json.TokenStream.init(buf[0..read]); log.err("Unexpected Json response from IMDS endpoint: {s}", .{resp});
const res = std.json.parse(ImdsResponse, &stream, .{ .allocator = allocator }) catch |e| { log.err("Error parsing json: {}", .{e});
log.err("Unexpected Json response from IMDS endpoint: {s}", .{buf}); if (@errorReturnTrace()) |trace| {
log.err("Error parsing json: {}", .{e}); std.debug.dumpStackTrace(trace.*);
if (@errorReturnTrace()) |trace| { }
std.debug.dumpStackTrace(trace.*); return null;
}
return null;
};
break :blk res;
}; };
defer std.json.parseFree(ImdsResponse, imds_response, .{ .allocator = allocator }); defer imds_response.deinit();
const role_arn = imds_response.InstanceProfileArn; const role_arn = imds_response.value.InstanceProfileArn;
const first_slash = std.mem.indexOf(u8, role_arn, "/"); // I think this is valid const first_slash = std.mem.indexOf(u8, role_arn, "/"); // I think this is valid
if (first_slash == null) { if (first_slash == null) {
log.err("Could not find role name in arn '{s}'", .{role_arn}); log.err("Could not find role name in arn '{s}'", .{role_arn});
@ -266,29 +272,37 @@ fn getImdsRoleName(allocator: std.mem.Allocator, imds_token: []u8) !?[]const u8
} }
/// Note - this internal function assumes zfetch is initialized prior to use /// Note - this internal function assumes zfetch is initialized prior to use
fn getImdsCredentials(allocator: std.mem.Allocator, role_name: []const u8, imds_token: []u8) !?auth.Credentials { fn getImdsCredentials(allocator: std.mem.Allocator, client: *std.http.Client, role_name: []const u8, imds_token: []u8) !?auth.Credentials {
var buf: [2048]u8 = undefined; var headers = std.http.Headers.init(allocator);
var headers = zfetch.Headers.init(allocator);
defer headers.deinit(); defer headers.deinit();
try headers.appendValue("X-aws-ec2-metadata-token", imds_token); try headers.append("X-aws-ec2-metadata-token", imds_token);
const url = try std.fmt.allocPrint(allocator, "http://169.254.169.254/latest/meta-data/iam/security-credentials/{s}/", .{role_name}); const url = try std.fmt.allocPrint(allocator, "http://169.254.169.254/latest/meta-data/iam/security-credentials/{s}/", .{role_name});
defer allocator.free(url); defer allocator.free(url);
var req = try zfetch.Request.init(allocator, url, null);
var req = try client.request(.GET, try std.Uri.parse(url), headers, .{});
defer req.deinit(); defer req.deinit();
try req.do(.GET, headers, null); try req.start();
try req.wait();
if (req.status.code != 200) { if (req.response.status != .ok and req.response.status != .not_found) {
log.warn("Bad status code received from IMDS role endpoint: {}", .{req.status.code}); log.warn("Bad status code received from IMDS role endpoint: {}", .{@intFromEnum(req.response.status)});
return null; return null;
} }
const reader = req.reader(); if (req.response.status == .not_found) return null;
const read = try reader.read(&buf); if (req.response.content_length == null or req.response.content_length.? == 0) {
if (read == 0 or read == 2048) { log.warn("Unexpected empty response from IMDS role endpoint", .{});
log.warn("Unexpected zero or long response from IMDS role endpoint: {s}", .{buf});
return null; return null;
} }
// TODO: This is still stupid
var resp_payload = try std.ArrayList(u8).initCapacity(allocator, req.response.content_length.?);
defer resp_payload.deinit();
try resp_payload.resize(req.response.content_length.?);
const resp = try resp_payload.toOwnedSlice();
defer allocator.free(resp);
_ = try req.readAll(resp);
// log.debug("Read {d} bytes from imds v2 credentials endpoint", .{read}); // log.debug("Read {d} bytes from imds v2 credentials endpoint", .{read});
const ImdsResponse = struct { const ImdsResponse = struct {
Code: []const u8, Code: []const u8,
@ -299,26 +313,22 @@ fn getImdsCredentials(allocator: std.mem.Allocator, role_name: []const u8, imds_
Token: []const u8, Token: []const u8,
Expiration: []const u8, Expiration: []const u8,
}; };
const imds_response = blk: { const imds_response = std.json.parseFromSlice(ImdsResponse, allocator, resp, .{}) catch |e| {
var stream = std.json.TokenStream.init(buf[0..read]); log.err("Unexpected Json response from IMDS endpoint: {s}", .{resp});
const res = std.json.parse(ImdsResponse, &stream, .{ .allocator = allocator }) catch |e| { log.err("Error parsing json: {}", .{e});
log.err("Unexpected Json response from IMDS endpoint: {s}", .{buf}); if (@errorReturnTrace()) |trace| {
log.err("Error parsing json: {}", .{e}); std.debug.dumpStackTrace(trace.*);
if (@errorReturnTrace()) |trace| { }
std.debug.dumpStackTrace(trace.*);
}
return null; return null;
};
break :blk res;
}; };
defer std.json.parseFree(ImdsResponse, imds_response, .{ .allocator = allocator }); defer imds_response.deinit();
const ret = auth.Credentials.init( const ret = auth.Credentials.init(
allocator, allocator,
try allocator.dupe(u8, imds_response.AccessKeyId), try allocator.dupe(u8, imds_response.value.AccessKeyId),
try allocator.dupe(u8, imds_response.SecretAccessKey), try allocator.dupe(u8, imds_response.value.SecretAccessKey),
try allocator.dupe(u8, imds_response.Token), try allocator.dupe(u8, imds_response.value.Token),
); );
log.debug("IMDSv2 credentials found. Access key: {s}", .{ret.access_key}); log.debug("IMDSv2 credentials found. Access key: {s}", .{ret.access_key});

View File

@ -11,8 +11,6 @@ const std = @import("std");
const base = @import("aws_http_base.zig"); const base = @import("aws_http_base.zig");
const signing = @import("aws_signing.zig"); const signing = @import("aws_signing.zig");
const credentials = @import("aws_credentials.zig"); const credentials = @import("aws_credentials.zig");
const zfetch = @import("zfetch");
const tls = @import("iguanaTLS");
const CN_NORTH_1_HASH = std.hash_map.hashString("cn-north-1"); const CN_NORTH_1_HASH = std.hash_map.hashString("cn-north-1");
const CN_NORTHWEST_1_HASH = std.hash_map.hashString("cn-northwest-1"); const CN_NORTHWEST_1_HASH = std.hash_map.hashString("cn-northwest-1");
@ -21,10 +19,6 @@ const US_ISOB_EAST_1_HASH = std.hash_map.hashString("us-isob-east-1");
const log = std.log.scoped(.awshttp); const log = std.log.scoped(.awshttp);
const amazon_root_ca_1 = @embedFile("Amazon_Root_CA_1.pem");
pub const default_root_ca = amazon_root_ca_1;
pub const AwsError = error{ pub const AwsError = error{
AddHeaderError, AddHeaderError,
AlpnError, AlpnError,
@ -67,27 +61,19 @@ const EndPoint = struct {
}; };
pub const AwsHttp = struct { pub const AwsHttp = struct {
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
trust_chain: ?tls.x509.CertificateChain,
const Self = @This(); const Self = @This();
/// Recommend usage is init(allocator, awshttp.default_root_ca) /// Recommend usage is init(allocator, awshttp.default_root_ca)
/// Passing null for root_pem will result in no TLS verification /// Passing null for root_pem will result in no TLS verification
pub fn init(allocator: std.mem.Allocator, root_pem: ?[]const u8) !Self { pub fn init(allocator: std.mem.Allocator) !Self {
var trust_chain: ?tls.x509.CertificateChain = null;
if (root_pem) |p| {
var fbs = std.io.fixedBufferStream(p);
trust_chain = try tls.x509.CertificateChain.from_pem(allocator, fbs.reader());
}
return Self{ return Self{
.allocator = allocator, .allocator = allocator,
.trust_chain = trust_chain,
// .credentialsProvider = // creds provider could be useful // .credentialsProvider = // creds provider could be useful
}; };
} }
pub fn deinit(self: *AwsHttp) void { pub fn deinit(self: *AwsHttp) void {
if (self.trust_chain) |c| c.deinit();
_ = self; _ = self;
log.debug("Deinit complete", .{}); log.debug("Deinit complete", .{});
} }
@ -173,12 +159,10 @@ pub const AwsHttp = struct {
} }
} }
try zfetch.init(); // This only does anything on Windows. Not sure how performant it is to do this on every request var headers = std.http.Headers.init(self.allocator);
defer zfetch.deinit();
var headers = zfetch.Headers.init(self.allocator);
defer headers.deinit(); defer headers.deinit();
for (request_cp.headers) |header| for (request_cp.headers) |header|
try headers.appendValue(header.name, header.value); try headers.append(header.name, header.value);
log.debug("All Request Headers (zfetch):", .{}); log.debug("All Request Headers (zfetch):", .{});
for (headers.list.items) |h| { for (headers.list.items) |h| {
log.debug("\t{s}: {s}", .{ h.name, h.value }); log.debug("\t{s}: {s}", .{ h.name, h.value });
@ -187,24 +171,36 @@ pub const AwsHttp = struct {
const url = try std.fmt.allocPrint(self.allocator, "{s}{s}{s}", .{ endpoint.uri, request_cp.path, request_cp.query }); const url = try std.fmt.allocPrint(self.allocator, "{s}{s}{s}", .{ endpoint.uri, request_cp.path, request_cp.query });
defer self.allocator.free(url); defer self.allocator.free(url);
log.debug("Request url: {s}", .{url}); log.debug("Request url: {s}", .{url});
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! var cl = std.http.Client{ .allocator = self.allocator };
// PLEASE READ!! IF YOU ARE LOOKING AT THIS LINE OF CODE DUE TO A defer cl.deinit(); // TODO: Connection pooling
// SEGFAULT IN INIT, IT IS PROBABLY BECAUSE THE HOST DOES NOT EXIST //
// https://github.com/ziglang/zig/issues/11358 // var req = try zfetch.Request.init(self.allocator, url, self.trust_chain);
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // defer req.deinit();
var req = try zfetch.Request.init(self.allocator, url, self.trust_chain);
defer req.deinit();
const method = std.meta.stringToEnum(zfetch.Method, request_cp.method).?; const method = std.meta.stringToEnum(std.http.Method, request_cp.method).?;
try req.do(method, headers, if (request_cp.body.len == 0) null else request_cp.body); var req = try cl.request(method, try std.Uri.parse(url), headers, .{});
if (request_cp.body.len > 0)
req.transfer_encoding = .{ .content_length = request_cp.body.len };
try req.start();
if (request_cp.body.len > 0) {
try req.writeAll(request_cp.body);
try req.finish();
}
try req.wait();
// TODO: Timeout - is this now above us? // TODO: Timeout - is this now above us?
log.debug("Request Complete. Response code {d}: {s}", .{ req.status.code, req.status.reason }); log.debug(
"Request Complete. Response code {d}: {s}",
.{ @intFromEnum(req.response.status), req.response.status.phrase() },
);
log.debug("Response headers:", .{}); log.debug("Response headers:", .{});
var resp_headers = try std.ArrayList(Header).initCapacity(self.allocator, req.headers.list.items.len); var resp_headers = try std.ArrayList(Header).initCapacity(
self.allocator,
req.response.headers.list.items.len,
);
defer resp_headers.deinit(); defer resp_headers.deinit();
var content_length: usize = 0; var content_length: usize = 0;
for (req.headers.list.items) |h| { for (req.response.headers.list.items) |h| {
log.debug(" {s}: {s}", .{ h.name, h.value }); log.debug(" {s}: {s}", .{ h.name, h.value });
resp_headers.appendAssumeCapacity(.{ resp_headers.appendAssumeCapacity(.{
.name = try (self.allocator.dupe(u8, h.name)), .name = try (self.allocator.dupe(u8, h.name)),
@ -213,22 +209,20 @@ pub const AwsHttp = struct {
if (content_length == 0 and std.ascii.eqlIgnoreCase("content-length", h.name)) if (content_length == 0 and std.ascii.eqlIgnoreCase("content-length", h.name))
content_length = std.fmt.parseInt(usize, h.value, 10) catch 0; content_length = std.fmt.parseInt(usize, h.value, 10) catch 0;
} }
const reader = req.reader();
var buf: [65535]u8 = undefined; // TODO: This is still stupid. Allocate a freaking array
var resp_payload = try std.ArrayList(u8).initCapacity(self.allocator, content_length); var resp_payload = try std.ArrayList(u8).initCapacity(self.allocator, content_length);
defer resp_payload.deinit(); defer resp_payload.deinit();
try resp_payload.resize(content_length);
while (true) { var response_data = try resp_payload.toOwnedSlice();
const read = try reader.read(&buf); errdefer self.allocator.free(response_data);
try resp_payload.appendSlice(buf[0..read]); _ = try req.readAll(response_data);
if (read == 0) break; log.debug("raw response body:\n{s}", .{response_data});
}
log.debug("raw response body:\n{s}", .{resp_payload.items});
const rc = HttpResult{ const rc = HttpResult{
.response_code = req.status.code, .response_code = @intFromEnum(req.response.status),
.body = resp_payload.toOwnedSlice(), .body = response_data,
.headers = resp_headers.toOwnedSlice(), .headers = try resp_headers.toOwnedSlice(),
.allocator = self.allocator, .allocator = self.allocator,
}; };
return rc; return rc;

View File

@ -283,7 +283,9 @@ pub fn freeSignedRequest(allocator: std.mem.Allocator, request: *base.Request, c
} }
} }
if (remove_len > 0) if (remove_len > 0)
request.headers = allocator.resize(request.headers, request.headers.len - remove_len).?; // TODO: We should not be discarding this return value
// Why on earth are we resizing the array if we're about to free the whole thing anyway?
_ = allocator.resize(request.headers, request.headers.len - remove_len);
allocator.free(request.headers); allocator.free(request.headers);
} }
@ -434,7 +436,7 @@ fn encodeParamPart(allocator: std.mem.Allocator, path: []const u8) ![]const u8 {
should_encode = false; should_encode = false;
break; break;
}; };
if (should_encode and std.ascii.isAlNum(c)) if (should_encode and std.ascii.isAlphanumeric(c))
should_encode = false; should_encode = false;
if (!should_encode) { if (!should_encode) {
@ -468,7 +470,7 @@ fn encodeUri(allocator: std.mem.Allocator, path: []const u8) ![]const u8 {
break; break;
}; };
} }
if (should_encode and std.ascii.isAlNum(c)) if (should_encode and std.ascii.isAlphanumeric(c))
should_encode = false; should_encode = false;
if (!should_encode) { if (!should_encode) {
@ -538,7 +540,7 @@ fn canonicalQueryString(allocator: std.mem.Allocator, path: []const u8) ![]const
defer sort_me.deinit(); defer sort_me.deinit();
while (portions.next()) |item| while (portions.next()) |item|
try sort_me.append(item); try sort_me.append(item);
std.sort.sort([]const u8, sort_me.items, {}, lessThanBinary); std.sort.pdq([]const u8, sort_me.items, {}, lessThanBinary);
var normalized = try std.ArrayList(u8).initCapacity(allocator, path.len); var normalized = try std.ArrayList(u8).initCapacity(allocator, path.len);
defer normalized.deinit(); defer normalized.deinit();
@ -642,7 +644,7 @@ fn canonicalHeaders(allocator: std.mem.Allocator, headers: []base.Header, servic
try dest.append(.{ .name = n, .value = v }); try dest.append(.{ .name = n, .value = v });
} }
std.sort.sort(base.Header, dest.items, {}, lessThan); std.sort.pdq(base.Header, dest.items, {}, lessThan);
var dest_str = try std.ArrayList(u8).initCapacity(allocator, total_len); var dest_str = try std.ArrayList(u8).initCapacity(allocator, total_len);
defer dest_str.deinit(); defer dest_str.deinit();
@ -660,8 +662,8 @@ fn canonicalHeaders(allocator: std.mem.Allocator, headers: []base.Header, servic
signed_headers.appendSliceAssumeCapacity(h.name); signed_headers.appendSliceAssumeCapacity(h.name);
} }
return CanonicalHeaders{ return CanonicalHeaders{
.str = dest_str.toOwnedSlice(), .str = try dest_str.toOwnedSlice(),
.signed_headers = signed_headers.toOwnedSlice(), .signed_headers = try signed_headers.toOwnedSlice(),
}; };
} }
@ -672,12 +674,12 @@ fn canonicalHeaderValue(allocator: std.mem.Allocator, value: []const u8) ![]cons
const rc = try allocator.alloc(u8, value.len); const rc = try allocator.alloc(u8, value.len);
var rc_inx: usize = 0; var rc_inx: usize = 0;
for (value, 0..) |c, i| { for (value, 0..) |c, i| {
if (!started and !std.ascii.isSpace(c)) { if (!started and !std.ascii.isWhitespace(c)) {
started = true; started = true;
start = i; start = i;
} }
if (started) { if (started) {
if (!in_quote and i > 0 and std.ascii.isSpace(c) and std.ascii.isSpace(value[i - 1])) if (!in_quote and i > 0 and std.ascii.isWhitespace(c) and std.ascii.isWhitespace(value[i - 1]))
continue; continue;
// if (c == '"') in_quote = !in_quote; // if (c == '"') in_quote = !in_quote;
rc[rc_inx] = c; rc[rc_inx] = c;
@ -685,7 +687,7 @@ fn canonicalHeaderValue(allocator: std.mem.Allocator, value: []const u8) ![]cons
} }
} }
// Trim end // Trim end
while (std.ascii.isSpace(rc[rc_inx - 1])) while (std.ascii.isWhitespace(rc[rc_inx - 1]))
rc_inx -= 1; rc_inx -= 1;
return rc[0..rc_inx]; return rc[0..rc_inx];
} }

View File

@ -336,8 +336,8 @@ fn isLeapYear(year: u16) bool {
} }
fn leapYearsBetween(start_year_inclusive: u16, end_year_exclusive: u16) u16 { fn leapYearsBetween(start_year_inclusive: u16, end_year_exclusive: u16) u16 {
const start = std.math.min(start_year_inclusive, end_year_exclusive); const start = @min(start_year_inclusive, end_year_exclusive);
const end = std.math.max(start_year_inclusive, end_year_exclusive); const end = @max(start_year_inclusive, end_year_exclusive);
var current = start; var current = start;
log.debug("Leap years starting from {d}, ending at {d}", .{ start, end }); log.debug("Leap years starting from {d}, ending at {d}", .{ start, end });
while (current % 4 != 0 and current < end) { while (current % 4 != 0 and current < end) {

View File

@ -1624,7 +1624,7 @@ fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options:
inline for (unionInfo.fields) |u_field| { inline for (unionInfo.fields) |u_field| {
// take a copy of tokens so we can withhold mutations until success // take a copy of tokens so we can withhold mutations until success
var tokens_copy = tokens.*; var tokens_copy = tokens.*;
if (parseInternal(u_field.field_type, token, &tokens_copy, options)) |value| { if (parseInternal(u_field.type, token, &tokens_copy, options)) |value| {
tokens.* = tokens_copy; tokens.* = tokens_copy;
return @unionInit(T, u_field.name, value); return @unionInit(T, u_field.name, value);
} else |err| { } else |err| {
@ -1654,7 +1654,7 @@ fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options:
@setEvalBranchQuota(100000); @setEvalBranchQuota(100000);
inline for (structInfo.fields, 0..) |field, i| { inline for (structInfo.fields, 0..) |field, i| {
if (fields_seen[i] and !field.is_comptime) { if (fields_seen[i] and !field.is_comptime) {
parseFree(field.field_type, @field(r, field.name), options); parseFree(field.type, @field(r, field.name), options);
} }
} }
} }
@ -1683,16 +1683,16 @@ fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options:
} else if (options.duplicate_field_behavior == .Error) { } else if (options.duplicate_field_behavior == .Error) {
return error.DuplicateJSONField; return error.DuplicateJSONField;
} else if (options.duplicate_field_behavior == .UseLast) { } else if (options.duplicate_field_behavior == .UseLast) {
parseFree(field.field_type, @field(r, field.name), options); parseFree(field.type, @field(r, field.name), options);
fields_seen[i] = false; fields_seen[i] = false;
} }
} }
if (field.is_comptime) { if (field.is_comptime) {
if (!try parsesTo(field.field_type, field.default_value.?, tokens, options)) { if (!try parsesTo(field.type, field.default_value.?, tokens, options)) {
return error.UnexpectedValue; return error.UnexpectedValue;
} }
} else { } else {
@field(r, field.name) = try parse(field.field_type, tokens, options); @field(r, field.name) = try parse(field.type, tokens, options);
} }
fields_seen[i] = true; fields_seen[i] = true;
found = true; found = true;
@ -1722,9 +1722,10 @@ fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options:
} }
inline for (structInfo.fields, 0..) |field, i| { inline for (structInfo.fields, 0..) |field, i| {
if (!fields_seen[i]) { if (!fields_seen[i]) {
if (field.default_value) |default| { if (field.default_value) |default_value_ptr| {
if (!field.is_comptime) { if (!field.is_comptime) {
@field(r, field.name) = default; const default_value = @as(*align(1) const field.type, @ptrCast(default_value_ptr)).*;
@field(r, field.name) = default_value;
} }
} else { } else {
if (!options.allow_missing_fields) if (!options.allow_missing_fields)
@ -1815,33 +1816,36 @@ fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options:
} }
}, },
.ObjectBegin => { .ObjectBegin => {
// TODO: Fix this, or better yet, try to switch
// back to standard json parse
return error.NotConvertedToZig11;
// We are parsing into a slice, but we have an // We are parsing into a slice, but we have an
// ObjectBegin. This might be ok, iff the type // ObjectBegin. This might be ok, iff the type
// follows this pattern: []struct { key: []const u8, value: anytype } // follows this pattern: []struct { key: []const u8, value: anytype }
// (could key be anytype?). // (could key be anytype?).
if (!isMapPattern(T)) // if (!isMapPattern(T))
return error.UnexpectedToken; // return error.UnexpectedToken;
var arraylist = std.ArrayList(ptrInfo.child).init(allocator); // var arraylist = std.ArrayList(ptrInfo.child).init(allocator);
errdefer { // errdefer {
while (arraylist.popOrNull()) |v| { // while (arraylist.popOrNull()) |v| {
parseFree(ptrInfo.child, v, options); // parseFree(ptrInfo.child, v, options);
} // }
arraylist.deinit(); // arraylist.deinit();
} // }
while (true) { // while (true) {
const key = (try tokens.next()) orelse return error.UnexpectedEndOfJson; // const key = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
switch (key) { // switch (key) {
.ObjectEnd => break, // .ObjectEnd => break,
else => {}, // else => {},
} // }
//
try arraylist.ensureTotalCapacity(arraylist.items.len + 1); // try arraylist.ensureTotalCapacity(arraylist.items.len + 1);
const key_val = try parseInternal(try typeForField(ptrInfo.child, "key"), key, tokens, options); // const key_val = try parseInternal(try typeForField(ptrInfo.child, "key"), key, tokens, options);
const val = (try tokens.next()) orelse return error.UnexpectedEndOfJson; // const val = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
const val_val = try parseInternal(try typeForField(ptrInfo.child, "value"), val, tokens, options); // const val_val = try parseInternal(try typeForField(ptrInfo.child, "value"), val, tokens, options);
arraylist.appendAssumeCapacity(.{ .key = key_val, .value = val_val }); // arraylist.appendAssumeCapacity(.{ .key = key_val, .value = val_val });
} // }
return arraylist.toOwnedSlice(); // return arraylist.toOwnedSlice();
}, },
else => return error.UnexpectedToken, else => return error.UnexpectedToken,
} }
@ -1854,13 +1858,13 @@ fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options:
unreachable; unreachable;
} }
fn typeForField(comptime T: type, field_name: []const u8) !type { fn typeForField(comptime T: type, comptime field_name: []const u8) !type {
const ti = @typeInfo(T); const ti = @typeInfo(T);
switch (ti) { switch (ti) {
.Struct => { .Struct => {
inline for (ti.Struct.fields) |field| { inline for (ti.Struct.fields) |field| {
if (std.mem.eql(u8, field.name, field_name)) if (std.mem.eql(u8, field.name, field_name))
return field.field_type; return field.type;
} }
}, },
else => return error.TypeIsNotAStruct, // should not hit this else => return error.TypeIsNotAStruct, // should not hit this
@ -1907,7 +1911,7 @@ pub fn parseFree(comptime T: type, value: T, options: ParseOptions) void {
if (unionInfo.tag_type) |UnionTagType| { if (unionInfo.tag_type) |UnionTagType| {
inline for (unionInfo.fields) |u_field| { inline for (unionInfo.fields) |u_field| {
if (value == @field(UnionTagType, u_field.name)) { if (value == @field(UnionTagType, u_field.name)) {
parseFree(u_field.field_type, @field(value, u_field.name), options); parseFree(u_field.type, @field(value, u_field.name), options);
break; break;
} }
} }
@ -1917,7 +1921,7 @@ pub fn parseFree(comptime T: type, value: T, options: ParseOptions) void {
}, },
.Struct => |structInfo| { .Struct => |structInfo| {
inline for (structInfo.fields) |field| { inline for (structInfo.fields) |field| {
parseFree(field.field_type, @field(value, field.name), options); parseFree(field.type, @field(value, field.name), options);
} }
}, },
.Array => |arrayInfo| { .Array => |arrayInfo| {
@ -2855,7 +2859,7 @@ pub fn stringify(
} }
inline for (S.fields) |Field| { inline for (S.fields) |Field| {
// don't include void fields // don't include void fields
if (Field.field_type == void) continue; if (Field.type == void) continue;
if (!field_output) { if (!field_output) {
field_output = true; field_output = true;
@ -3172,5 +3176,5 @@ test "stringify struct with custom stringifier" {
} }
test "stringify vector" { test "stringify vector" {
// try teststringify("[1,1]", @splat(2, @as(u32, 1)), StringifyOptions{}); try teststringify("[1,1]", @as(@Vector(2, u32), @splat(@as(u32, 1))), StringifyOptions{});
} }

View File

@ -318,13 +318,13 @@ pub fn main() anyerror!void {
std.log.info("===== Tests complete =====", .{}); std.log.info("===== Tests complete =====", .{});
} }
fn typeForField(comptime T: type, field_name: []const u8) !type { fn typeForField(comptime T: type, comptime field_name: []const u8) !type {
const ti = @typeInfo(T); const ti = @typeInfo(T);
switch (ti) { switch (ti) {
.Struct => { .Struct => {
inline for (ti.Struct.fields) |field| { inline for (ti.Struct.fields) |field| {
if (std.mem.eql(u8, field.name, field_name)) if (std.mem.eql(u8, field.name, field_name))
return field.field_type; return field.type;
} }
}, },
else => return error.TypeIsNotAStruct, // should not hit this else => return error.TypeIsNotAStruct, // should not hit this

View File

@ -1,57 +1,72 @@
const std = @import("std"); const std = @import("std");
fn defaultTransformer(field_name: []const u8, _: EncodingOptions) anyerror![]const u8 { fn defaultTransformer(allocator: std.mem.Allocator, field_name: []const u8, options: EncodingOptions) anyerror![]const u8 {
_ = options;
_ = allocator;
return field_name; return field_name;
} }
pub const FieldNameTransformer = fn ([]const u8, EncodingOptions) anyerror![]const u8; pub const fieldNameTransformerFn = *const fn (std.mem.Allocator, []const u8, EncodingOptions) anyerror![]const u8;
pub const EncodingOptions = struct { pub const EncodingOptions = struct {
allocator: ?std.mem.Allocator = null, field_name_transformer: fieldNameTransformerFn = &defaultTransformer,
field_name_transformer: *const FieldNameTransformer = &defaultTransformer,
}; };
pub fn encode(obj: anytype, writer: anytype, options: EncodingOptions) !void { pub fn encode(allocator: std.mem.Allocator, obj: anytype, writer: anytype, comptime options: EncodingOptions) !void {
_ = try encodeInternal("", "", true, obj, writer, options); _ = try encodeInternal(allocator, "", "", true, obj, writer, options);
} }
fn encodeStruct(parent: []const u8, first: bool, obj: anytype, writer: anytype, options: EncodingOptions) !bool { fn encodeStruct(
allocator: std.mem.Allocator,
parent: []const u8,
first: bool,
obj: anytype,
writer: anytype,
comptime options: EncodingOptions,
) !bool {
var rc = first; var rc = first;
inline for (@typeInfo(@TypeOf(obj)).Struct.fields) |field| { inline for (@typeInfo(@TypeOf(obj)).Struct.fields) |field| {
const field_name = try options.field_name_transformer.*(field.name, options); const field_name = try options.field_name_transformer(allocator, field.name, options);
defer if (options.field_name_transformer.* != defaultTransformer) defer if (options.field_name_transformer.* != defaultTransformer)
if (options.allocator) |a| a.free(field_name); allocator.free(field_name);
// @compileLog(@typeInfo(field.field_type).Pointer); // @compileLog(@typeInfo(field.field_type).Pointer);
rc = try encodeInternal(parent, field_name, rc, @field(obj, field.name), writer, options); rc = try encodeInternal(allocator, parent, field_name, rc, @field(obj, field.name), writer, options);
} }
return rc; return rc;
} }
pub fn encodeInternal(parent: []const u8, field_name: []const u8, first: bool, obj: anytype, writer: anytype, options: EncodingOptions) !bool { pub fn encodeInternal(
allocator: std.mem.Allocator,
parent: []const u8,
field_name: []const u8,
first: bool,
obj: anytype,
writer: anytype,
comptime options: EncodingOptions,
) !bool {
// @compileLog(@typeInfo(@TypeOf(obj))); // @compileLog(@typeInfo(@TypeOf(obj)));
var rc = first; var rc = first;
switch (@typeInfo(@TypeOf(obj))) { switch (@typeInfo(@TypeOf(obj))) {
.Optional => if (obj) |o| { .Optional => if (obj) |o| {
rc = try encodeInternal(parent, field_name, first, o, writer, options); rc = try encodeInternal(allocator, parent, field_name, first, o, writer, options);
}, },
.Pointer => |ti| if (ti.size == .One) { .Pointer => |ti| if (ti.size == .One) {
rc = try encodeInternal(parent, field_name, first, obj.*, writer, options); rc = try encodeInternal(allocator, parent, field_name, first, obj.*, writer, options);
} else { } else {
if (!first) _ = try writer.write("&"); if (!first) _ = try writer.write("&");
try writer.print("{s}{s}={s}", .{ parent, field_name, obj }); try writer.print("{s}{s}={s}", .{ parent, field_name, obj });
rc = false; rc = false;
}, },
.Struct => if (std.mem.eql(u8, "", field_name)) { .Struct => if (std.mem.eql(u8, "", field_name)) {
rc = try encodeStruct(parent, first, obj, writer, options); rc = try encodeStruct(allocator, parent, first, obj, writer, options);
} else { } else {
// TODO: It would be lovely if we could concat at compile time or allocPrint at runtime // TODO: It would be lovely if we could concat at compile time or allocPrint at runtime
// XOR have compile time allocator support. Alas, neither are possible: // XOR have compile time allocator support. Alas, neither are possible:
// https://github.com/ziglang/zig/issues/868: Comptime detection (feels like foot gun) // https://github.com/ziglang/zig/issues/868: Comptime detection (feels like foot gun)
// https://github.com/ziglang/zig/issues/1291: Comptime allocator // https://github.com/ziglang/zig/issues/1291: Comptime allocator
const allocator = options.allocator orelse return error.AllocatorRequired;
const new_parent = try std.fmt.allocPrint(allocator, "{s}{s}.", .{ parent, field_name }); const new_parent = try std.fmt.allocPrint(allocator, "{s}{s}.", .{ parent, field_name });
defer allocator.free(new_parent); defer allocator.free(new_parent);
rc = try encodeStruct(new_parent, first, obj, writer, options); rc = try encodeStruct(allocator, new_parent, first, obj, writer, options);
// try encodeStruct(parent ++ field_name ++ ".", first, obj, writer, options); // try encodeStruct(parent ++ field_name ++ ".", first, obj, writer, options);
}, },
.Array => { .Array => {

View File

@ -70,7 +70,8 @@ pub const Element = struct {
} }
pub fn findChildByTag(self: *Element, tag: []const u8) !?*Element { pub fn findChildByTag(self: *Element, tag: []const u8) !?*Element {
return try self.findChildrenByTag(tag).next(); var it = self.findChildrenByTag(tag);
return try it.next();
} }
pub fn findChildrenByTag(self: *Element, tag: []const u8) FindChildrenByTagIterator { pub fn findChildrenByTag(self: *Element, tag: []const u8) FindChildrenByTagIterator {
@ -116,7 +117,7 @@ pub const Element = struct {
pub const FindChildrenByTagIterator = struct { pub const FindChildrenByTagIterator = struct {
inner: ChildElementIterator, inner: ChildElementIterator,
tag: []const u8, tag: []const u8,
predicate: fn (a: []const u8, b: []const u8, options: PredicateOptions) anyerror!bool = strictEqual, predicate: *const fn (a: []const u8, b: []const u8, options: PredicateOptions) anyerror!bool = strictEqual,
predicate_options: PredicateOptions = .{}, predicate_options: PredicateOptions = .{},
pub fn next(self: *FindChildrenByTagIterator) !?*Element { pub fn next(self: *FindChildrenByTagIterator) !?*Element {
@ -650,7 +651,10 @@ fn dupeAndUnescape(alloc: Allocator, text: []const u8) ![]const u8 {
} }
} }
return alloc.shrink(str, j); // This error is not strictly true, but we need to match one of the items
// from the error set provided by the other stdlib calls at the calling site
if (!alloc.resize(str, j)) return error.OutOfMemory;
return str;
} }
test "dupeAndUnescape" { test "dupeAndUnescape" {

View File

@ -83,7 +83,7 @@ pub fn parse(comptime T: type, source: []const u8, options: ParseOptions) !Parse
errdefer parsed.deinit(); errdefer parsed.deinit();
const opts = ParseOptions{ const opts = ParseOptions{
.allocator = aa, .allocator = aa,
.match_predicate = options.match_predicate, .match_predicate_ptr = options.match_predicate_ptr,
}; };
return Parsed(T).init(arena_allocator, try parseInternal(T, parsed.root, opts), parsed); return Parsed(T).init(arena_allocator, try parseInternal(T, parsed.root, opts), parsed);
@ -123,7 +123,7 @@ fn parseInternal(comptime T: type, element: *xml.Element, options: ParseOptions)
// We have an iso8601 in an integer field (we think) // We have an iso8601 in an integer field (we think)
// Try to coerce this into our type // Try to coerce this into our type
const timestamp = try date.parseIso8601ToTimestamp(element.children.items[0].CharData); const timestamp = try date.parseIso8601ToTimestamp(element.children.items[0].CharData);
return try std.math.cast(T, timestamp); return std.math.cast(T, timestamp).?;
} }
if (log_parse_traces) { if (log_parse_traces) {
std.log.err( std.log.err(
@ -167,7 +167,7 @@ fn parseInternal(comptime T: type, element: *xml.Element, options: ParseOptions)
// inline for (union_info.fields) |u_field| { // inline for (union_info.fields) |u_field| {
// // take a copy of tokens so we can withhold mutations until success // // take a copy of tokens so we can withhold mutations until success
// var tokens_copy = tokens.*; // var tokens_copy = tokens.*;
// if (parseInternal(u_field.field_type, token, &tokens_copy, options)) |value| { // if (parseInternal(u_field.type, token, &tokens_copy, options)) |value| {
// tokens.* = tokens_copy; // tokens.* = tokens_copy;
// return @unionInit(T, u_field.name, value); // return @unionInit(T, u_field.name, value);
// } else |err| { // } else |err| {
@ -193,7 +193,7 @@ fn parseInternal(comptime T: type, element: *xml.Element, options: ParseOptions)
// @setEvalBranchQuota(100000); // @setEvalBranchQuota(100000);
// inline for (struct_info.fields) |field, i| { // inline for (struct_info.fields) |field, i| {
// if (fields_seen[i] and !field.is_comptime) { // if (fields_seen[i] and !field.is_comptime) {
// parseFree(field.field_type, @field(r, field.name), options); // parseFree(field.type, @field(r, field.name), options);
// } // }
// } // }
// } // }
@ -220,31 +220,31 @@ fn parseInternal(comptime T: type, element: *xml.Element, options: ParseOptions)
name = r.fieldNameFor(field.name); name = r.fieldNameFor(field.name);
log.debug("Field name: {s}, Element: {s}, Adjusted field name: {s}", .{ field.name, element.tag, name }); log.debug("Field name: {s}, Element: {s}, Adjusted field name: {s}", .{ field.name, element.tag, name });
var iterator = element.findChildrenByTag(name); var iterator = element.findChildrenByTag(name);
if (options.match_predicate) |predicate| { if (options.match_predicate_ptr) |predicate_ptr| {
iterator.predicate = predicate; iterator.predicate = predicate_ptr;
iterator.predicate_options = .{ .allocator = options.allocator.? }; iterator.predicate_options = .{ .allocator = options.allocator.? };
} }
if (try iterator.next()) |child| { if (try iterator.next()) |child| {
// I don't know that we would use comptime here. I'm also // I don't know that we would use comptime here. I'm also
// not sure the nuance of setting this... // not sure the nuance of setting this...
// if (field.is_comptime) { // if (field.is_comptime) {
// if (!try parsesTo(field.field_type, field.default_value.?, tokens, options)) { // if (!try parsesTo(field.type, field.default_value.?, tokens, options)) {
// return error.UnexpectedValue; // return error.UnexpectedValue;
// } // }
// } else { // } else {
log.debug("Found child element {s}", .{child.tag}); log.debug("Found child element {s}", .{child.tag});
// TODO: how do we errdefer this? // TODO: how do we errdefer this?
@field(r, field.name) = try parseInternal(field.field_type, child, options); @field(r, field.name) = try parseInternal(field.type, child, options);
fields_seen[i] = true; fields_seen[i] = true;
fields_set = fields_set + 1; fields_set = fields_set + 1;
found_value = true; found_value = true;
} }
if (@typeInfo(field.field_type) == .Optional) { if (@typeInfo(field.type) == .Optional) {
// Test "compiler assertion failure 2" // Test "compiler assertion failure 2"
// Zig compiler bug circa 0.9.0. Using "and !found_value" // Zig compiler bug circa 0.9.0. Using "and !found_value"
// in the if statement above will trigger assertion failure // in the if statement above will trigger assertion failure
if (!found_value) { if (!found_value) {
// @compileLog("Optional: Field name ", field.name, ", type ", field.field_type); // @compileLog("Optional: Field name ", field.name, ", type ", field.type);
@field(r, field.name) = null; @field(r, field.name) = null;
fields_set = fields_set + 1; fields_set = fields_set + 1;
found_value = true; found_value = true;