From 812ad96a668eda7ac634bf2274ae2708b787d7e5 Mon Sep 17 00:00:00 2001 From: Emil Lerch Date: Tue, 29 Aug 2023 11:24:34 -0700 Subject: [PATCH] add proxy support --- src/aws.zig | 9 +++++---- src/aws_http.zig | 8 ++++---- src/main.zig | 30 ++++++++++++++++++++++++++++-- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/src/aws.zig b/src/aws.zig index a719ad0..1b74b46 100644 --- a/src/aws.zig +++ b/src/aws.zig @@ -27,18 +27,19 @@ pub const services = servicemodel.services; /// This will give you a constant with service data for sts, ec2, s3 and ddb only pub const Services = servicemodel.Services; -pub const ClientOptions = struct {}; +pub const ClientOptions = struct { + proxy: ?std.http.Client.HttpProxy = null, +}; pub const Client = struct { allocator: std.mem.Allocator, aws_http: awshttp.AwsHttp, const Self = @This(); - pub fn init(allocator: std.mem.Allocator, options: ClientOptions) !Self { - _ = options; + pub fn init(allocator: std.mem.Allocator, options: ClientOptions) Self { return Self{ .allocator = allocator, - .aws_http = try awshttp.AwsHttp.init(allocator), + .aws_http = awshttp.AwsHttp.init(allocator, options.proxy), }; } pub fn deinit(self: *Client) void { diff --git a/src/aws_http.zig b/src/aws_http.zig index 9dc44ff..9f46800 100644 --- a/src/aws_http.zig +++ b/src/aws_http.zig @@ -61,14 +61,14 @@ const EndPoint = struct { }; pub const AwsHttp = struct { allocator: std.mem.Allocator, + proxy: ?std.http.Client.HttpProxy, const Self = @This(); - /// Recommend usage is init(allocator, awshttp.default_root_ca) - /// Passing null for root_pem will result in no TLS verification - pub fn init(allocator: std.mem.Allocator) !Self { + pub fn init(allocator: std.mem.Allocator, proxy: ?std.http.Client.HttpProxy) Self { return Self{ .allocator = allocator, + .proxy = proxy, // .credentialsProvider = // creds provider could be useful }; } @@ -171,7 +171,7 @@ pub const AwsHttp = struct { const url = try std.fmt.allocPrint(self.allocator, "{s}{s}{s}", .{ endpoint.uri, request_cp.path, request_cp.query }); defer self.allocator.free(url); log.debug("Request url: {s}", .{url}); - var cl = std.http.Client{ .allocator = self.allocator }; + var cl = std.http.Client{ .allocator = self.allocator, .proxy = self.proxy }; defer cl.deinit(); // TODO: Connection pooling // // var req = try zfetch.Request.init(self.allocator, url, self.trust_chain); diff --git a/src/main.zig b/src/main.zig index 17d38cc..a9bcf54 100644 --- a/src/main.zig +++ b/src/main.zig @@ -68,11 +68,12 @@ pub fn main() anyerror!void { defer bw.flush() catch unreachable; const stdout = bw.writer(); var arg0: ?[]const u8 = null; + var proxy: ?std.http.Client.HttpProxy = null; while (args.next()) |arg| { if (arg0 == null) arg0 = arg; if (std.mem.eql(u8, "-h", arg) or std.mem.eql(u8, "--help", arg)) { try stdout.print( - \\usage: {?s} [-h|--help] [-v][-v][-v] [test_name...] + \\usage: {?s} [-h|--help] [-v][-v][-v] [-x|--proxy ] [tests...] \\ \\Where tests are one of the following: \\ @@ -82,6 +83,10 @@ pub fn main() anyerror!void { } return; } + if (std.mem.eql(u8, "-x", arg) or std.mem.eql(u8, "--proxy", arg)) { + proxy = try proxyFromString(args.next().?); // parse stuff + continue; + } if (std.mem.eql(u8, "-v", arg)) { verbose += 1; continue; @@ -99,7 +104,8 @@ pub fn main() anyerror!void { } std.log.info("Start\n", .{}); - var client = try aws.Client.init(allocator, .{}); + const client_options = aws.ClientOptions{ .proxy = proxy }; + var client = aws.Client.init(allocator, client_options); const options = aws.Options{ .region = "us-west-2", .client = client, @@ -339,6 +345,26 @@ pub fn main() anyerror!void { std.log.info("===== Tests complete =====", .{}); } + +fn proxyFromString(string: []const u8) !std.http.Client.HttpProxy { + var rc = std.http.Client.HttpProxy{ + .protocol = undefined, + .host = undefined, + }; + var remaining: []const u8 = string; + if (std.mem.startsWith(u8, string, "http://")) { + remaining = remaining["http://".len..]; + rc.protocol = .plain; + } else if (std.mem.startsWith(u8, string, "https://")) { + remaining = remaining["https://".len..]; + rc.protocol = .tls; + } else return error.InvalidScheme; + var split_iterator = std.mem.split(u8, remaining, ":"); + rc.host = std.mem.trimRight(u8, split_iterator.first(), "/"); + if (split_iterator.next()) |port| + rc.port = try std.fmt.parseInt(u16, port, 10); + return rc; +} fn typeForField(comptime T: type, comptime field_name: []const u8) !type { const ti = @typeInfo(T); switch (ti) {