From d3efa21a415866f1176692f5c22cf8bd1f2fdd03 Mon Sep 17 00:00:00 2001 From: Emil Lerch Date: Tue, 27 Apr 2021 11:24:01 -0700 Subject: [PATCH] first thing that actually works --- .gitignore | 2 + Dockerfile | 104 ++++ LICENSE | 21 + Makefile | 15 + README.md | 103 ++++ build.zig | 51 ++ src/aws.zig | 1145 +++++++++++++++++++++++++++++++++++++ src/bitfield-workaround.c | 34 ++ src/bitfield-workaround.h | 142 +++++ src/bool.zig | 55 ++ src/main.zig | 63 ++ src/xml.zig | 649 +++++++++++++++++++++ 12 files changed, 2384 insertions(+) create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 build.zig create mode 100644 src/aws.zig create mode 100644 src/bitfield-workaround.c create mode 100644 src/bitfield-workaround.h create mode 100644 src/bool.zig create mode 100644 src/main.zig create mode 100644 src/xml.zig diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fef7a8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.cache +zig-cache diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..9a28e8d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,104 @@ +FROM alpine:3.13 AS base +# gcc gets us libgcc.a, even though the build should be using clang +RUN apk add --no-cache clang git cmake make lld musl-dev gcc && \ + rm /usr/bin/ld && \ + ln -s /usr/bin/ld.lld /usr/bin/ld && rm /usr/bin/gcc # just to be sure + +FROM base AS common +# d5f9398d48d9c318563db08100e2e87b24ea3656 +# RUN git clone --depth 1 -b pthread-np https://github.com/r-burns/aws-c-common && \ +RUN git clone --depth 1 -b v0.5.2 https://github.com/awslabs/aws-c-common && \ + mkdir aws-c-common-build && cd aws-c-common-build && \ + cmake ../aws-c-common && \ + make -j12 && make test && make install + +RUN tar -czf aws-c-common-clang.tgz /usr/local/* + +FROM base AS openssl +RUN apk add --no-cache perl linux-headers && \ + git clone --depth 1 -b OpenSSL_1_1_1i 
https://github.com/openssl/openssl && \ + cd openssl && ./Configure linux-x86_64-clang && make && make install + +RUN tar -czf openssl-clang.tgz /usr/local/* + +FROM base AS s2n +ENV S2N_LIBCRYPTO=openssl-1.1.1 +COPY --from=openssl /openssl-clang.tgz / +RUN git clone --depth 1 -b v0.10.26 https://github.com/awslabs/s2n && \ + tar -xzf openssl-clang.tgz && \ + mkdir s2n-build && cd s2n-build && \ + cmake ../s2n && \ + make -j12 && make install + +RUN tar -czf s2n-clang.tgz /usr/local/* + +FROM base AS cal +COPY --from=openssl /openssl-clang.tgz / +COPY --from=common /aws-c-common-clang.tgz / +# environment not used - just busting docker's cache +ENV COMMIT=d1a4d +# RUN git clone --depth 1 -b v0.4.5 https://github.com/awslabs/aws-c-cal && \ +RUN git clone --depth 1 https://github.com/elerch/aws-c-cal && \ + tar -xzf aws-c-common-clang.tgz && \ + tar -xzf openssl-clang.tgz && \ + mkdir cal-build && cd cal-build && \ + cmake -DCMAKE_MODULE_PATH=/usr/local/lib64/cmake ../aws-c-cal && \ + make -j12 && make install +# No make test: +# 40 - ecdsa_p384_test_key_gen_export (Failed) +RUN tar -czf aws-c-cal-clang.tgz /usr/local/* + +FROM base AS compression +COPY --from=common /aws-c-common-clang.tgz / +RUN git clone --depth 1 -b v0.2.10 https://github.com/awslabs/aws-c-compression && \ + tar -xzf aws-c-common-clang.tgz && \ + mkdir compression-build && cd compression-build && \ + cmake -DCMAKE_MODULE_PATH=/usr/local/lib64/cmake ../aws-c-compression && \ + make -j12 && make test && make install + +RUN tar -czf aws-c-compression-clang.tgz /usr/local/* + +FROM base AS io +# Cal includes common and openssl +COPY --from=cal /aws-c-cal-clang.tgz / +COPY --from=s2n /s2n-clang.tgz / +RUN git clone --depth 1 -b v0.9.1 https://github.com/awslabs/aws-c-io && \ + tar -xzf s2n-clang.tgz && \ + tar -xzf aws-c-cal-clang.tgz && \ + mkdir io-build && cd io-build && \ + cmake -DCMAKE_MODULE_PATH=/usr/local/lib64/cmake ../aws-c-io && \ + make -j12 && make install + +RUN tar -czf 
aws-c-io-clang.tgz /usr/local/* + +FROM base AS http +# Cal includes common and openssl +# 2 test failures on musl - both "download medium file" +COPY --from=io /aws-c-io-clang.tgz / +COPY --from=compression /aws-c-compression-clang.tgz / +# RUN git clone --depth 1 -b v0.5.19 https://github.com/awslabs/aws-c-http && \ +RUN git clone --depth 1 -b v0.6.1 https://github.com/awslabs/aws-c-http && \ + tar -xzf aws-c-io-clang.tgz && \ + tar -xzf aws-c-compression-clang.tgz && \ + mkdir http-build && cd http-build && \ + cmake -DCMAKE_MODULE_PATH=/usr/local/lib64/cmake ../aws-c-http && \ + make -j12 && make install + +RUN tar -czf aws-c-http-clang.tgz /usr/local/* + +FROM base AS auth +# http should have all other dependencies +COPY --from=http /aws-c-http-clang.tgz / +RUN git clone --depth 1 -b v0.5.0 https://github.com/awslabs/aws-c-auth && \ + tar -xzf aws-c-http-clang.tgz && \ + mkdir auth-build && cd auth-build && \ + cmake -DCMAKE_MODULE_PATH=/usr/local/lib64/cmake ../aws-c-auth && \ + make -j12 && make install # chunked_signing_test fails + +RUN tar -czf aws-c-auth-clang.tgz /usr/local/* + +FROM alpine:3.13 as final +COPY --from=auth /aws-c-auth-clang.tgz / +ADD https://ziglang.org/download/0.7.1/zig-linux-x86_64-0.7.1.tar.xz / +RUN tar -xzf /aws-c-auth-clang.tgz && mkdir /src && tar -C /usr/local -xf zig-linux* && \ + ln -s /usr/local/zig-linux*/zig /usr/local/bin/zig diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0feb744 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Emil Lerch + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: 
+ +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fb87279 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +start-hand-test: src/main.zig src/aws.zig src/xml.zig + @zig build-exe -static -I/usr/local/include -Isrc/ -lc --strip \ + --name start-hand-test src/main.zig src/bitfield-workaround.c \ + /usr/local/lib64/libaws-c-*.a \ + /usr/local/lib64/libs2n.a \ + /usr/local/lib/libcrypto.a \ + /usr/local/lib/libssl.a + +elasticurl: curl.c + @zig build-exe -static -I/usr/local/include -Isrc/ -lc --strip \ + --name elasticurl curl.c \ + /usr/local/lib64/libaws-c-*.a \ + /usr/local/lib64/libs2n.a \ + /usr/local/lib/libcrypto.a \ + /usr/local/lib/libssl.a diff --git a/README.md b/README.md new file mode 100644 index 0000000..070b04c --- /dev/null +++ b/README.md @@ -0,0 +1,103 @@ +# AWS SDK for Zig + +Ok, so it's not actually an SDK (yet). Right now this SDK supports the sts +get-caller-identity action only. Why? Because it's one of the easiest to +support, so I started there. From here, the next major step is to codegen +the types necessary to support the various services. Currently this code is +dynamically generating the sts types so we are somewhat codegen ready, but +current comptime limitations might trip us up. The advantage of comptime is +that only types actually used would be generated vs the whole surface area +of AWS. 
That said, with most of the heavy lifting now coded, the addition +of the request/response types, even if all of them are added, should not +balloon the size beyond "reasonable". Of course this still needs to be seen. + +This is my first serious zig effort, so please issue a PR if the code isn't +"ziggy" or if there's a better way. + +This is designed to be built statically using the `aws_c_*` libraries, so +we inherit a lot of the goodness of the work going on there. Implementing +get-caller-identity with all dependencies statically linked gives us a stripped +executable size of 5.3M for x86_linux (which is all that's tested at the moment). + +## Building + +I am assuming here that if you're playing with zig, you pretty much know +what you're doing, so I will stay brief. + +First, the dependencies are required. Use the Dockerfile to build these. +A `docker build` will do, but be prepared for it to run a while. Openssl in +particular will take a while, but without any particular knowledge +I'm also hoping/expecting AWS to factor out that library sometime in +the future. + +Once that's done, you'll have an alpine image with all dependencies ready +to go and zig 0.7.1 installed. The build.zig currently relies on +[this PR to allow stripping -static](https://github.com/ziglang/zig/pull/8248), +so either: + +* Modify build.zig, then strip (or not) after the fact +* Install make and use the included Makefile + +## Running + +This library uses the aws c libraries for its work, so it operates like most +other 'AWS things'. Note that I tested by setting the appropriate environment +variables, so config files haven't gotten a run through. +main.zig gives you a program to call sts GetCallerIdentity. +For local testing or alternative endpoints, there's no real standard, so +there is code to look for `AWS_ENDPOINT_URL` environment variable that will +supersede all other configuration. 
+ +## Dependencies + + +Full dependency tree: +aws-c-auth + * s2n + * openssl + * aws-c-common + * aws-c-compression + * aws-c-common + * aws-c-http + * s2n + * aws-c-common + * aws-c-io + * aws-c-common + * s2n + * openssl + * aws-c-cal + * aws-c-compression + * aws-c-common + * aws-c-cal + * aws-c-common + +Build order based on above: + +1. aws-c-common +1. openssl +2. s2n +2. aws-c-cal +2. aws-c-compression +3. aws-c-io +4. aws-c-http +5. aws-c-auth + +Dockerfile in this repo will manage this + +TODO List: + +* Implement jitter/exponential backoff. This appears to be configuration of `aws_c_io` and should therefore be trivial +* Implement timeouts and other TODO's in the code +* Implement error handling for 4xx, 5xx and other unexpected return values +* Implement generic response body -> Response type handling (right now, this is hard-coded) +* Implement codegen for services with xml structures (using Smithy models) +* Implement codegen for others (using Smithy models) +* Issue PR in c libraries for full static musl build support (see Dockerfile) +* Remove compiler 0.7.1 shims when 0.8.0 is released + +Compiler wishlist/watchlist: + +* Fix the weirdness we see with comptime type generation (see aws.zig around line 251) +* [Allow declarations for comptime type generation](https://github.com/ziglang/zig/issues/6709) +* [Merge PR to allow stripping -static](https://github.com/ziglang/zig/pull/8248) +* [comptime allocations](https://github.com/ziglang/zig/issues/1291) so we can read files, etc (or is there another way) diff --git a/build.zig b/build.zig new file mode 100644 index 0000000..8f5b12e --- /dev/null +++ b/build.zig @@ -0,0 +1,51 @@ +// const std = @import("std"); +const Builder = @import("std").build.Builder; + +pub fn build(b: *Builder) void { + // Standard target options allows the person running `zig build` to choose + // what target to build for. Here we do not override the defaults, which + // means any target is allowed, and the default is native. 
Other options + // for restricting supported target set are available. + const target = b.standardTargetOptions(.{}); + + // Standard release options allow the person running `zig build` to select + // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. + const mode = b.standardReleaseOptions(); + const exe = b.addExecutable("start-hand-test", "src/main.zig"); + + exe.addCSourceFile("src/bitfield-workaround.c", &[_][]const u8{"-std=c99"}); + exe.addIncludeDir("./src/"); + exe.addIncludeDir("/usr/local/include"); + exe.addObjectFile("/usr/local/lib64/libs2n.a"); + exe.addObjectFile("/usr/local/lib/libcrypto.a"); + exe.addObjectFile("/usr/local/lib/libssl.a"); + exe.addObjectFile("/usr/local/lib64/libaws-c-auth.a"); + exe.addObjectFile("/usr/local/lib64/libaws-c-cal.a"); + exe.addObjectFile("/usr/local/lib64/libaws-c-common.a"); + exe.addObjectFile("/usr/local/lib64/libaws-c-compression.a"); + exe.addObjectFile("/usr/local/lib64/libaws-c-http.a"); + exe.addObjectFile("/usr/local/lib64/libaws-c-io.a"); + exe.linkSystemLibrary("c"); + exe.setTarget(target); + exe.setBuildMode(mode); + exe.override_dest_dir = .{ .Custom = ".." 
}; + + // TODO: Figure out -static + // Neither of these two work + // exe.addCompileFlags([][]const u8{ + // "-static", + // "--strip", + // }); + exe.is_static = true; + exe.strip = true; + exe.install(); + + const run_cmd = exe.run(); + run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_cmd.addArgs(args); + } + + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_cmd.step); +} diff --git a/src/aws.zig b/src/aws.zig new file mode 100644 index 0000000..1e50c00 --- /dev/null +++ b/src/aws.zig @@ -0,0 +1,1145 @@ +const std = @import("std"); +const xml = @import("xml.zig"); +const c = @cImport({ + @cInclude("bitfield-workaround.h"); + @cInclude("aws/common/allocator.h"); + @cInclude("aws/common/error.h"); + @cInclude("aws/common/string.h"); + @cInclude("aws/auth/auth.h"); + @cInclude("aws/auth/credentials.h"); + @cInclude("aws/auth/signable.h"); + @cInclude("aws/auth/signing_config.h"); + @cInclude("aws/auth/signing_result.h"); + @cInclude("aws/auth/signing.h"); + @cInclude("aws/http/connection.h"); + @cInclude("aws/http/request_response.h"); + @cInclude("aws/io/channel_bootstrap.h"); + @cInclude("aws/io/tls_channel_handler.h"); + @cInclude("aws/io/event_loop.h"); + @cInclude("aws/io/socket.h"); + @cInclude("aws/io/stream.h"); +}); +const std_atomic_bool = @import("bool.zig"); // This is in std in 0.8.0 + +const CN_NORTH_1_HASH = std.hash_map.hashString("cn-north-1"); +const CN_NORTHWEST_1_HASH = std.hash_map.hashString("cn-northwest-1"); +const US_ISO_EAST_1_HASH = std.hash_map.hashString("us-iso-east-1"); +const US_ISOB_EAST_1_HASH = std.hash_map.hashString("us-isob-east-1"); + +var reference_count: u32 = 0; +var c_allocator: ?*c.aws_allocator = null; +var c_logger: c.aws_logger = .{ + .vtable = null, + .allocator = null, + .p_impl = null, +}; + +const log = std.log.scoped(.aws); +const httplog = std.log.scoped(.awshttp); + +// Code "generation" prototype +// TODO: Make generic +pub fn Services() type { + const types = 
[_]type{ + Service("sts"), + }; + return @Type(.{ + .Struct = .{ + .layout = .Auto, + .fields = &[_]std.builtin.TypeInfo.StructField{ + .{ + .name = "sts", + .field_type = types[0], + .default_value = new(types[0]), + .is_comptime = false, + .alignment = 0, + }, + }, + .decls = &[_]std.builtin.TypeInfo.Declaration{}, + .is_tuple = false, + }, + }); +} + +fn ServiceActionResponse(comptime service: []const u8, comptime action: []const u8) type { + if (std.mem.eql(u8, service, "sts") and std.mem.eql(u8, action, "get_caller_identity")) { + return struct { + arn: []const u8, + user_id: []const u8, + account: []const u8, + response_metadata: ResponseMetadata, + + allocator: *std.mem.Allocator, + raw_response: xml.Document, + + // If this is purely generic we won't be able to do it here as + // declarations aren't supported yet + // pub fn deinit(self: *const GetCallerIdentityResponse) void { + // self.responseMetadata.deinit(); + // self.rawResponse.deinit(); + // } + }; + } + unreachable; +} + +fn ServiceAction(comptime service: []const u8, comptime action: []const u8) type { + if (std.mem.eql(u8, service, "sts") and std.mem.eql(u8, action, "get_caller_identity")) { + return @Type(.{ + .Struct = .{ + .layout = .Auto, + .fields = &[_]std.builtin.TypeInfo.StructField{ + .{ + .name = "Request", + .field_type = type, + .default_value = struct {}, + .is_comptime = false, + .alignment = 0, + }, + .{ + .name = "action_name", + .field_type = @TypeOf("GetCallerIdentity"), + .default_value = "GetCallerIdentity", + .is_comptime = false, + .alignment = 0, + }, + // TODO: maybe best is to separate requests from responses in whole other struct? 
+ .{ + .name = "Response", + .field_type = type, + .default_value = ServiceActionResponse("sts", "get_caller_identity"), + .is_comptime = false, + .alignment = 0, + }, + }, + .decls = &[_]std.builtin.TypeInfo.Declaration{}, + .is_tuple = false, + }, + }); + } + unreachable; +} + +pub const services = Services(){}; + +fn new(comptime T: type) T { + return T{}; +} +fn Service(comptime service: []const u8) type { + if (std.mem.eql(u8, "sts", service)) { + return @Type(.{ + .Struct = .{ + .layout = .Auto, + .fields = &[_]std.builtin.TypeInfo.StructField{ + .{ + .name = "version", + .field_type = @TypeOf("2011-06-15"), + .default_value = "2011-06-15", + .is_comptime = false, + .alignment = 0, + }, + .{ + .name = "get_caller_identity", + .field_type = ServiceAction("sts", "get_caller_identity"), + .default_value = new(ServiceAction("sts", "get_caller_identity")), + .is_comptime = false, + .alignment = 0, + }, + }, + .decls = &[_]std.builtin.TypeInfo.Declaration{}, + .is_tuple = false, + }, + }); + } + unreachable; +} +// End code "generation" prototype + +pub const Aws = struct { + allocator: *std.mem.Allocator, + bootstrap: *c.aws_client_bootstrap, + resolver: *c.aws_host_resolver, + eventLoopGroup: *c.aws_event_loop_group, + credentialsProvider: *c.aws_credentials_provider, + + var tls_ctx_options: ?*c.aws_tls_ctx_options = null; + var tls_ctx: ?*c.aws_tls_ctx = null; + + pub fn responseDeinit(raw_response: xml.Document, response_metadata: ?ResponseMetadata) void { + raw_response.deinit(); + if (response_metadata) |meta| { + meta.deinit(); + } + } + + fn AsyncResult(comptime T: type) type { + return struct { + result: *T, + requiredCount: u32 = 1, + sync: std_atomic_bool.Bool = std_atomic_bool.Bool.init(false), // This is a 0.8.0 feature... 
:( + count: u8 = 0, + }; + } + + fn AwsAsyncCallbackResult(comptime T: type) type { + return struct { + result: ?*T = null, + error_code: i32 = c.AWS_ERROR_SUCCESS, + }; + } + + const Self = @This(); + + pub fn init(allocator: *std.mem.Allocator) Self { + if (reference_count == 0) cInit(allocator); + reference_count += 1; + log.debug("auth ref count: {}", .{reference_count}); + // TODO; determine appropriate lifetime for the bootstrap and credentials' + // provider + // Mostly stolen from aws_c_auth/credentials_tests.c + const el_group = c.aws_event_loop_group_new_default(c_allocator, 1, null); + + var resolver_options = c.aws_host_resolver_default_options{ + .el_group = el_group, + .max_entries = 8, + .shutdown_options = null, // not set in test + .system_clock_override_fn = null, // not set in test + }; + + const resolver = c.aws_host_resolver_new_default(c_allocator, &resolver_options); + + const bootstrap_options = c.aws_client_bootstrap_options{ + .host_resolver = resolver, + .on_shutdown_complete = null, // was set in test + .host_resolution_config = null, + .user_data = null, + .event_loop_group = el_group, + }; + + const bootstrap = c.aws_client_bootstrap_new(c_allocator, &bootstrap_options); + const provider_chain_options = c.aws_credentials_provider_chain_default_options{ + .bootstrap = bootstrap, + .shutdown_options = c.aws_credentials_provider_shutdown_options{ + .shutdown_callback = null, // was set on test + .shutdown_user_data = null, + }, + }; + return .{ + .allocator = allocator, + .bootstrap = bootstrap, + .resolver = resolver, + .eventLoopGroup = el_group, + .credentialsProvider = c.aws_credentials_provider_new_chain_default(c_allocator, &provider_chain_options), + }; + } + pub fn deinit(self: *Aws) void { + if (reference_count > 0) + reference_count -= 1; + log.debug("deinit: auth ref count: {}", .{reference_count}); + c.aws_credentials_provider_release(self.credentialsProvider); + // TODO: Wait for provider shutdown? 
https://github.com/awslabs/aws-c-auth/blob/c394e30808816a8edaab712e77f79f480c911d3a/tests/credentials_tests.c#L197 + c.aws_client_bootstrap_release(self.bootstrap); + c.aws_host_resolver_release(self.resolver); + c.aws_event_loop_group_release(self.eventLoopGroup); + if (reference_count == 0) { + cDeinit(); + log.debug("Deinit complete", .{}); + } + } + pub fn call(self: Self, comptime request: anytype, options: Options) !Response(request) { + const action_info = actionForRequest(request); + // This is true weirdness, but we are running into compiler bugs. Touch only if + // prepared... + const service = @field(services, action_info.service); + const action = @field(service, action_info.action); + const R = Response(request); + + log.debug("service {s}", .{action_info.service}); + log.debug("version {s}", .{service.version}); + log.debug("action {s}", .{action.action_name}); + const response = try self.callApi(action_info.service, service.version, action.action_name, options); + defer response.deinit(); + // TODO: Check status code for badness + const doc = try xml.parse(self.allocator, response.body); + const result = doc.root.findChildByTag("GetCallerIdentityResult"); + return R{ + .arn = result.?.getCharData("Arn").?, + .user_id = result.?.getCharData("UserId").?, + .account = result.?.getCharData("Account").?, + .allocator = self.allocator, + .raw_response = doc, + .response_metadata = try metadataFromResponse(self.allocator, response.body), + }; + } + fn actionForRequest(comptime request: anytype) struct { service: []const u8, action: []const u8, service_obj: anytype } { + const type_name = @typeName(@TypeOf(request)); + var service_start: usize = 0; + var service_end: usize = 0; + var action_start: usize = 0; + var action_end: usize = 0; + for (type_name) |ch, i| { + switch (ch) { + '(' => service_start = i + 2, + ')' => action_end = i - 1, + ',' => { + service_end = i - 1; + action_start = i + 2; + }, + else => continue, + } + } + // const zero: usize = 0; + 
// TODO: Figure out why if statement isn't working + // if (serviceStart == zero or serviceEnd == zero or actionStart == zero or actionEnd == zero) { + // @compileLog("Type must be a function with two parameters \"service\" and \"action\". Found: " ++ type_name); + // // @compileError("Type must be a function with two parameters \"service\" and \"action\". Found: " ++ type_name); + // } + return .{ + .service = type_name[service_start..service_end], + .action = type_name[action_start..action_end], + .service_obj = @field(services, type_name[service_start..service_end]), + }; + } + fn Response(comptime request: anytype) type { + const action_info = actionForRequest(request); + const service = @field(services, action_info.service); + const action = @field(service, action_info.action); + return action.Response; + } + fn callApi(self: Self, service: []const u8, version: []const u8, action: []const u8, options: Options) !HttpResult { + const endpoint = try regionSubDomain(self.allocator, service, options.region, options.dualstack); + defer endpoint.deinit(); + const body = try std.fmt.allocPrint(self.allocator, "Action={s}&Version={s}\n", .{ action, version }); + defer self.allocator.free(body); + httplog.debug("Calling {s}.{s}, endpoint {s}", .{ service, action, endpoint.uri }); + const signing_options: SigningOptions = .{ + .region = options.region, + .service = service, + }; + return try self.makeRequest(endpoint, "POST", "/", body, signing_options); + } + + fn signRequest(self: Self, http_request: *c.aws_http_message, options: SigningOptions) !void { + const creds = try self.getCredentials(); + defer c.aws_credentials_release(creds); + // print the access key. Creds are an opaque C type, so we + // use aws_credentials_get_access_key_id. That gets us an aws_byte_cursor, + // from which we create a new aws_string with the contents. 
We need + // to convert to c_str with aws_string_c_str + const access_key = c.aws_string_new_from_cursor(c_allocator, &c.aws_credentials_get_access_key_id(creds)); + defer c.aws_mem_release(c_allocator, access_key); + // defer c_allocator.*.mem_release.?(c_allocator, access_key); + log.debug("Signing with access key: {s}", .{c.aws_string_c_str(access_key)}); + + const signable = c.aws_signable_new_http_request(c_allocator, http_request); + if (signable == null) { + log.warn("Could not create signable request", .{}); + return AwsError.SignableError; + } + defer c.aws_signable_destroy(signable); + + const signing_region = try std.fmt.allocPrint(self.allocator, "{s}", .{options.region}); + defer self.allocator.free(signing_region); + const signing_service = try std.fmt.allocPrint(self.allocator, "{s}", .{options.service}); + defer self.allocator.free(signing_service); + const temp_signing_config = c.bitfield_workaround_aws_signing_config_aws{ + .algorithm = .AWS_SIGNING_ALGORITHM_V4, + .config_type = .AWS_SIGNING_CONFIG_AWS, + .signature_type = .AWS_ST_HTTP_REQUEST_HEADERS, + .region = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, signing_region)), + .service = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, signing_service)), + .should_sign_header = null, + .should_sign_header_ud = null, + .flags = c.bitfield_workaround_aws_signing_config_aws_flags{ + .use_double_uri_encode = 0, + .should_normalize_uri_path = 0, + .omit_session_token = 1, + }, + .signed_body_value = c.aws_byte_cursor_from_c_str(""), + .signed_body_header = .AWS_SBHT_X_AMZ_CONTENT_SHA256, //or AWS_SBHT_NONE + .credentials = creds, + .credentials_provider = self.credentialsProvider, + .expiration_in_seconds = 0, + }; + var signing_config = c.new_aws_signing_config(c_allocator, &temp_signing_config); + defer c.aws_mem_release(c_allocator, signing_config); + var signing_result = AwsAsyncCallbackResult(c.aws_http_message){ .result = http_request }; + var sign_result_request = 
AsyncResult(AwsAsyncCallbackResult(c.aws_http_message)){ .result = &signing_result }; + if (c.aws_sign_request_aws(c_allocator, signable, fullCast([*c]const c.aws_signing_config_base, signing_config), signComplete, &sign_result_request) != c.AWS_OP_SUCCESS) { + const error_code = c.aws_last_error(); + log.alert("Could not initiate signing request: {s}:{s}", .{ c.aws_error_name(error_code), c.aws_error_str(error_code) }); + return AwsError.SigningInitiationError; + } + + // Wait for callback. Note that execution, including real work of signing + // the http request, will continue in signComplete (below), + // then continue beyond this line + waitOnCallback(c.aws_http_message, &sign_result_request); + if (sign_result_request.result.error_code != c.AWS_ERROR_SUCCESS) { + return AwsError.SignableError; + } + } + + /// It's my theory that the aws event loop has a trigger to corrupt the + /// signing result after this call completes. So the technique of assigning + /// now, using later will not work + fn signComplete(result: ?*c.aws_signing_result, error_code: c_int, user_data: ?*c_void) callconv(.C) void { + var async_result = userDataTo(AsyncResult(AwsAsyncCallbackResult(c.aws_http_message)), user_data); + var http_request = async_result.result.result; + async_result.sync.store(true, .SeqCst); + + async_result.count += 1; + async_result.result.error_code = error_code; + + if (result) |res| { + if (c.aws_apply_signing_result_to_http_request(http_request, c_allocator, result) != c.AWS_OP_SUCCESS) { + log.alert("Could not apply signing request to http request: {s}", .{c.aws_error_debug_str(c.aws_last_error())}); + } + log.debug("signing result applied", .{}); + } else { + log.alert("Did not receive signing result: {s}", .{c.aws_error_debug_str(c.aws_last_error())}); + } + async_result.sync.store(false, .SeqCst); + } + + fn fullCast(comptime T: type, val: anytype) T { + return @ptrCast(T, @alignCast(@alignOf(T), val)); + } + + const HttpResult = struct { + body: []const 
u8, + fn deinit(self: HttpResult) void { + httplog.debug("http result deinit complete", .{}); + return; + } + }; + + // This is a fairly generic "make an http/https request" method and could + // potentially be extracted to another type that's non-AWS specific. + // It does make AWS signing if signingoptions are passed, which could be + // some function passed in, or just left as needed. + fn makeRequest(self: Self, endpoint: EndPoint, method: []const u8, path: []const u8, body: []const u8, signing_options: ?SigningOptions) !HttpResult { + // TODO: Try to re-encapsulate this + // var http_request = try createRequest(method, path, body); + + // TODO: Likely this should be encapsulated more + var http_request = c.aws_http_message_new_request(c_allocator); + defer c.aws_http_message_release(http_request); + // TODO: Verify if AWS cares about these headers (probably should be passing them...) + // Accept-Encoding: identity + // Content-Type: application/x-www-form-urlencoded + + if (c.aws_http_message_set_request_method(http_request, c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, method))) != c.AWS_OP_SUCCESS) + return AwsError.SetRequestMethodError; + + if (c.aws_http_message_set_request_path(http_request, c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, path))) != c.AWS_OP_SUCCESS) + return AwsError.SetRequestPathError; + + httplog.debug("body length: {d}", .{body.len}); + const body_cursor = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, body)); + const request_body = c.aws_input_stream_new_from_cursor(c_allocator, &body_cursor); + defer c.aws_input_stream_destroy(request_body); + if (body.len > 0) { + c.aws_http_message_set_body_stream(http_request, request_body); + } + + // End CreateRequest. 
This should return a struct with a deinit function that can do + // destroys, etc + + var context = RequestContext{ + .allocator = self.allocator, + }; + var tls_connection_options: ?*c.aws_tls_connection_options = null; + const host = try std.fmt.allocPrint(self.allocator, "{s}", .{endpoint.host}); + defer self.allocator.free(host); + try self.addHeaders(http_request.?, host, body); + if (std.mem.eql(u8, endpoint.scheme, "https")) { + // TODO: Figure out why this needs to be inline vs function call + // tls_connection_options = try self.setupTls(host); + if (Aws.tls_ctx_options == null) { + httplog.debug("Setting up tls options", .{}); + var opts: c.aws_tls_ctx_options = .{ + .allocator = c_allocator, + .minimum_tls_version = @intToEnum(c.aws_tls_versions, c.AWS_IO_TLS_VER_SYS_DEFAULTS), + .cipher_pref = @intToEnum(c.aws_tls_cipher_pref, c.AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT), + .ca_file = c.aws_byte_buf_from_c_str(""), + .ca_path = c.aws_string_new_from_c_str(c_allocator, ""), + .alpn_list = null, + .certificate = c.aws_byte_buf_from_c_str(""), + .private_key = c.aws_byte_buf_from_c_str(""), + .max_fragment_size = 0, + .verify_peer = true, + }; + Aws.tls_ctx_options = &opts; + + c.aws_tls_ctx_options_init_default_client(Aws.tls_ctx_options.?, c_allocator); + // h2;http/1.1 + if (c.aws_tls_ctx_options_set_alpn_list(Aws.tls_ctx_options, "http/1.1") != c.AWS_OP_SUCCESS) { + httplog.alert("Failed to load alpn list with error {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + return AwsError.AlpnError; + } + + Aws.tls_ctx = c.aws_tls_client_ctx_new(c_allocator, Aws.tls_ctx_options.?); + + if (Aws.tls_ctx == null) { + std.debug.panic("Failed to initialize TLS context with error {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + } + httplog.debug("tls options setup applied", .{}); + } + var conn_opts = c.aws_tls_connection_options{ + .alpn_list = null, + .server_name = null, + .on_negotiation_result = null, + .on_data_read = null, + .on_error = null, + 
.user_data = null, + .ctx = null, + .advertise_alpn_message = false, + .timeout_ms = 0, + }; + tls_connection_options = &conn_opts; + c.aws_tls_connection_options_init_from_ctx(tls_connection_options, tls_ctx); + var host_var = host; + var host_cur = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, host_var)); + if (c.aws_tls_connection_options_set_server_name(tls_connection_options, c_allocator, &host_cur) != c.AWS_OP_SUCCESS) { + httplog.alert("Failed to set servername with error {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + return AwsError.TlsError; + } + } + if (signing_options) |opts| try self.signRequest(http_request.?, opts); + const socket_options = c.aws_socket_options{ + .type = @intToEnum(c.aws_socket_type, c.AWS_SOCKET_STREAM), + .domain = @intToEnum(c.aws_socket_domain, c.AWS_SOCKET_IPV4), + .connect_timeout_ms = 3000, // TODO: change hardcoded 3s value + .keep_alive_timeout_sec = 0, + .keepalive = false, + .keep_alive_interval_sec = 0, + // If set, sets the number of keep alive probes allowed to fail before the connection is considered + // lost. If zero OS defaults are used. On Windows, this option is meaningless until Windows 10 1703. 
+ .keep_alive_max_failed_probes = 0, + }; + const http_client_options = c.aws_http_client_connection_options{ + .self_size = @sizeOf(c.aws_http_client_connection_options), + .socket_options = &socket_options, + .allocator = c_allocator, + .port = endpoint.port, + .host_name = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, host)), + .bootstrap = self.bootstrap, + .initial_window_size = c.SIZE_MAX, + .tls_options = tls_connection_options, + .user_data = &context, + .proxy_options = null, + .monitoring_options = null, + .http1_options = null, + .http2_options = null, + .manual_window_management = false, + .on_setup = connectionSetupCallback, + .on_shutdown = connectionShutdownCallback, + }; + if (c.aws_http_client_connect(&http_client_options) != c.AWS_OP_SUCCESS) { + httplog.alert("HTTP client connect failed with {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + return AwsError.HttpClientConnectError; + } + // TODO: Timeout + // Wait for connection to setup + while (!context.connection_complete.load(.SeqCst)) { + std.time.sleep(1 * std.time.ns_per_ms); + } + if (context.return_error) |e| return e; + + const request_options = c.aws_http_make_request_options{ + .self_size = @sizeOf(c.aws_http_make_request_options), + .on_response_headers = incomingHeadersCallback, + .on_response_header_block_done = null, + .on_response_body = incomingBodyCallback, + .on_complete = requestCompleteCallback, + .user_data = @ptrCast(*c_void, &context), + .request = http_request, + }; + + // C code + // app_ctx->response_code_written = false; + const stream = c.aws_http_connection_make_request(context.connection, &request_options); + if (stream == null) { + httplog.alert("failed to create request.", .{}); + return AwsError.RequestCreateError; + } + if (c.aws_http_stream_activate(stream) != c.AWS_OP_SUCCESS) { + httplog.alert("HTTP request failed with {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + return AwsError.HttpRequestError; + } + // TODO: Timeout + while 
(!context.request_complete.load(.SeqCst)) { + std.time.sleep(1 * std.time.ns_per_ms); + } + httplog.debug("request_complete. Response code {d}", .{context.response_code.?}); + httplog.debug("headers:", .{}); + for (context.headers.?.items) |h| { + httplog.debug(" {s}: {s}", .{ h.name, h.value }); + } + httplog.debug("raw response body:\n{s}", .{context.body}); + // Connection will stay alive until stream completes + c.aws_http_connection_release(context.connection); + context.connection = null; + if (tls_connection_options) |opts| { + c.aws_tls_connection_options_clean_up(opts); + } + var final_body: []const u8 = ""; + if (context.body) |b| { + final_body = b; + } + const rc = HttpResult{ + .body = final_body, + }; + return rc; + } + + // TODO: Re-encapsulate or delete this function. It is not currently + // used and will not be touched by the compiler + fn createRequest(method: []const u8, path: []const u8, body: []const u8) !*c.aws_http_message { + // TODO: Likely this should be encapsulated more + var http_request = c.aws_http_message_new_request(c_allocator); + // TODO: Verify if AWS cares about these headers (probably should be passing them...) 
+ // Accept-Encoding: identity + // Content-Type: application/x-www-form-urlencoded + + if (c.aws_http_message_set_request_method(http_request, c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, method))) != c.AWS_OP_SUCCESS) + return AwsError.SetRequestMethodError; + + if (c.aws_http_message_set_request_path(http_request, c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, path))) != c.AWS_OP_SUCCESS) + return AwsError.SetRequestPathError; + + const body_cursor = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, body)); + const request_body = c.aws_input_stream_new_from_cursor(c_allocator, &body_cursor); + defer c.aws_input_stream_destroy(request_body); + c.aws_http_message_set_body_stream(http_request, request_body); + return http_request.?; + } + fn addHeaders(self: Self, request: *c.aws_http_message, host: []const u8, body: []const u8) !void { + const accept_header = c.aws_http_header{ + .name = c.aws_byte_cursor_from_c_str("Accept"), + .value = c.aws_byte_cursor_from_c_str("*/*"), + .compression = .AWS_HTTP_HEADER_COMPRESSION_USE_CACHE, + }; + if (c.aws_http_message_add_header(request, accept_header) != c.AWS_OP_SUCCESS) + return AwsError.AddHeaderError; + + const host_header = c.aws_http_header{ + .name = c.aws_byte_cursor_from_c_str("Host"), + .value = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, host)), + .compression = .AWS_HTTP_HEADER_COMPRESSION_USE_CACHE, + }; + if (c.aws_http_message_add_header(request, host_header) != c.AWS_OP_SUCCESS) + return AwsError.AddHeaderError; + + const user_agent_header = c.aws_http_header{ + .name = c.aws_byte_cursor_from_c_str("User-Agent"), + .value = c.aws_byte_cursor_from_c_str("zig-aws 1.0, Powered by the AWS Common Runtime."), + .compression = .AWS_HTTP_HEADER_COMPRESSION_USE_CACHE, + }; + if (c.aws_http_message_add_header(request, user_agent_header) != c.AWS_OP_SUCCESS) + return AwsError.AddHeaderError; + + // AWS does not seem to care about Accept-Encoding + // Accept-Encoding: identity + // Content-Type: 
application/x-www-form-urlencoded + // const accept_encoding_header = c.aws_http_header{ + // .name = c.aws_byte_cursor_from_c_str("Accept-Encoding"), + // .value = c.aws_byte_cursor_from_c_str("identity"), + // .compression = .AWS_HTTP_HEADER_COMPRESSION_USE_CACHE, + // }; + // if (c.aws_http_message_add_header(request, accept_encoding_header) != c.AWS_OP_SUCCESS) + // return AwsError.AddHeaderError; + + // AWS *does* seem to care about Content-Type. I don't think this header + // will hold for all APIs + // TODO: Work out Content-type + const content_type_header = c.aws_http_header{ + .name = c.aws_byte_cursor_from_c_str("Content-Type"), + .value = c.aws_byte_cursor_from_c_str("application/x-www-form-urlencoded"), + .compression = .AWS_HTTP_HEADER_COMPRESSION_USE_CACHE, + }; + if (c.aws_http_message_add_header(request, content_type_header) != c.AWS_OP_SUCCESS) + return AwsError.AddHeaderError; + + if (body.len > 0) { + const len = try std.fmt.allocPrint(self.allocator, "{d}", .{body.len}); + // This defer seems to work ok, but I'm a bit concerned about why + defer self.allocator.free(len); + const content_length_header = c.aws_http_header{ + .name = c.aws_byte_cursor_from_c_str("Content-Length"), + .value = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, len)), + .compression = .AWS_HTTP_HEADER_COMPRESSION_USE_CACHE, + }; + if (c.aws_http_message_add_header(request, content_length_header) != c.AWS_OP_SUCCESS) + return AwsError.AddHeaderError; + } + } + + fn connectionSetupCallback(connection: ?*c.aws_http_connection, error_code: c_int, user_data: ?*c_void) callconv(.C) void { + httplog.debug("connection setup callback start", .{}); + var context = userDataTo(RequestContext, user_data); + if (error_code != c.AWS_OP_SUCCESS) { + httplog.alert("Failed to setup connection: {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + context.return_error = AwsError.SetupConnectionError; + } + context.connection = connection; + context.connection_complete.store(true, 
.SeqCst); + httplog.debug("connection setup callback end", .{}); + } + + fn connectionShutdownCallback(connection: ?*c.aws_http_connection, error_code: c_int, user_data: ?*c_void) callconv(.C) void { + httplog.debug("connection shutdown callback start", .{}); + httplog.debug("connection shutdown callback end", .{}); + } + + fn incomingHeadersCallback(stream: ?*c.aws_http_stream, header_block: c.aws_http_header_block, headers: [*c]const c.aws_http_header, num_headers: usize, user_data: ?*c_void) callconv(.C) c_int { + var context = userDataTo(RequestContext, user_data); + + if (context.response_code == null) { + var status: c_int = 0; + if (c.aws_http_stream_get_incoming_response_status(stream, &status) == c.AWS_OP_SUCCESS) { + context.response_code = @intCast(u16, status); // RFC says this is a 3 digit number, so c_int is silly + httplog.debug("response status code from callback: {d}", .{status}); + } else { + httplog.alert("could not get status code", .{}); + context.return_error = AwsError.StatusCodeError; + } + } + for (headers[0..num_headers]) |header| { + const name = header.name.ptr[0..header.name.len]; + const value = header.value.ptr[0..header.value.len]; + httplog.debug("header from callback: {s}: {s}", .{ name, value }); + context.addHeader(name, value) catch + httplog.alert("could not append header to request context", .{}); + } + return c.AWS_OP_SUCCESS; + } + fn incomingBodyCallback(stream: ?*c.aws_http_stream, data: [*c]const c.aws_byte_cursor, user_data: ?*c_void) callconv(.C) c_int { + var context = userDataTo(RequestContext, user_data); + + httplog.debug("inbound body, len {d}", .{data.*.len}); + const array = @ptrCast(*const []u8, &data.*.ptr).*; + // Need this to be a slice because it does not necessarily have a \0 sentinal + const body_chunk = array[0..data.*.len]; + context.appendToBody(body_chunk) catch + httplog.alert("could not append to body!", .{}); + return c.AWS_OP_SUCCESS; + } + fn requestCompleteCallback(stream: ?*c.aws_http_stream, 
error_code: c_int, user_data: ?*c_void) callconv(.C) void { + var context = userDataTo(RequestContext, user_data); + context.request_complete.store(true, .SeqCst); + c.aws_http_stream_release(stream); + httplog.debug("request complete", .{}); + } + + // TODO: Re-encapsulate or delete this function. It is not currently + // used and will not be touched by the compiler + fn setupTls(self: Self, host: []const u8) !*c.aws_tls_connection_options { + if (Aws.tls_ctx_options == null) { + httplog.debug("Setting up tls options", .{}); + var opts: c.aws_tls_ctx_options = .{ + .allocator = c_allocator, + .minimum_tls_version = @intToEnum(c.aws_tls_versions, c.AWS_IO_TLS_VER_SYS_DEFAULTS), + .cipher_pref = @intToEnum(c.aws_tls_cipher_pref, c.AWS_IO_TLS_CIPHER_PREF_SYSTEM_DEFAULT), + .ca_file = c.aws_byte_buf_from_c_str(""), + .ca_path = c.aws_string_new_from_c_str(c_allocator, ""), + .alpn_list = null, + .certificate = c.aws_byte_buf_from_c_str(""), + .private_key = c.aws_byte_buf_from_c_str(""), + .max_fragment_size = 0, + .verify_peer = true, + }; + Aws.tls_ctx_options = &opts; + + c.aws_tls_ctx_options_init_default_client(Aws.tls_ctx_options.?, c_allocator); + // h2;http/1.1 + if (c.aws_tls_ctx_options_set_alpn_list(Aws.tls_ctx_options, "http/1.1") != c.AWS_OP_SUCCESS) { + httplog.alert("Failed to load alpn list with error {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + return AwsError.AlpnError; + } + + Aws.tls_ctx = c.aws_tls_client_ctx_new(c_allocator, Aws.tls_ctx_options.?); + + if (Aws.tls_ctx == null) { + std.debug.panic("Failed to initialize TLS context with error {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + } + httplog.debug("tls options setup applied", .{}); + } + + var tls_connection_options = c.aws_tls_connection_options{ + .alpn_list = null, + .server_name = null, + .on_negotiation_result = null, + .on_data_read = null, + .on_error = null, + .user_data = null, + .ctx = null, + .advertise_alpn_message = false, + .timeout_ms = 0, + }; + 
c.aws_tls_connection_options_init_from_ctx(&tls_connection_options, tls_ctx); + var host_var = host; + var host_cur = c.aws_byte_cursor_from_c_str(@ptrCast([*c]const u8, host_var)); + if (c.aws_tls_connection_options_set_server_name(&tls_connection_options, c_allocator, &host_cur) != c.AWS_OP_SUCCESS) { + httplog.alert("Failed to set servername with error {s}.", .{c.aws_error_debug_str(c.aws_last_error())}); + return AwsError.TlsError; + } + return &tls_connection_options; + + // if (app_ctx.uri.port) { + // port = app_ctx.uri.port; + // } + } + + pub const AwsError = error{ + AddHeaderError, + AlpnError, + CredentialsError, + HttpClientConnectError, + HttpRequestError, + SignableError, + SigningInitiationError, + TlsError, + RequestCreateError, + SetupConnectionError, + StatusCodeError, + SetRequestMethodError, + SetRequestPathError, + }; + + fn getCredentials(self: Self) !*c.aws_credentials { + var credential_result = AwsAsyncCallbackResult(c.aws_credentials){}; + var callback_results = AsyncResult(AwsAsyncCallbackResult(c.aws_credentials)){ .result = &credential_result }; + + const callback = awsAsyncCallbackResult(c.aws_credentials, "got credentials", assignCredentialsOnCallback); + const get_async_result = + c.aws_credentials_provider_get_credentials(self.credentialsProvider, callback, &callback_results); + + waitOnCallback(c.aws_credentials, &callback_results); + if (credential_result.error_code != c.AWS_ERROR_SUCCESS) { + httplog.alert("Could not acquire credentials: {s}:{s}", .{ c.aws_error_name(credential_result.error_code), c.aws_error_str(credential_result.error_code) }); + return AwsError.CredentialsError; + } + return credential_result.result orelse unreachable; + } + + // Generic wait on callback function + fn waitOnCallback(comptime T: type, results: *AsyncResult(AwsAsyncCallbackResult(T))) void { + var done = false; + while (!done) { + // TODO: Timeout + // More context: 
https://github.com/ziglang/zig/blob/119fc318a753f57b55809e9256e823accba6b56a/lib/std/crypto/benchmark.zig#L45-L54 + // var timer = try std.time.Timer.start(); + // const start = timer.lap(); + // while (offset < bytes) : (offset += block.len) { + // do work + // + // h.update(block[0..]); + // } + // mem.doNotOptimizeAway(&h); + // const end = timer.read(); + // + // const elapsed_s = @intToFloat(f64, end - start) / time.ns_per_s; + while (results.sync.load(.SeqCst)) { + std.time.sleep(1 * std.time.ns_per_ms); + } + done = results.count >= results.requiredCount; + // TODO: Timeout + std.time.sleep(1 * std.time.ns_per_ms); + } + } + + // Generic function that generates a type-specific funtion for callback use + fn awsAsyncCallback(comptime T: type, comptime message: []const u8) (fn (result: ?*T, error_code: c_int, user_data: ?*c_void) callconv(.C) void) { + const inner = struct { + fn func(userData: *AsyncResult(AwsAsyncCallbackResult(T)), apiData: ?*T) void { + userData.result.result = apiData; + } + }; + return awsAsyncCallbackResult(T, message, inner.func); + } + + // used by awsAsyncCallbackResult to cast our generic userdata void * + // into a type known to zig + fn userDataTo(comptime T: type, userData: ?*c_void) *T { + return @ptrCast(*T, @alignCast(@alignOf(T), userData)); + } + + // generic callback ability. 
Takes a function for the actual assignment + // If you need a standard assignment, use awsAsyncCallback instead + fn awsAsyncCallbackResult(comptime T: type, comptime message: []const u8, comptime resultAssignment: (fn (user: *AsyncResult(AwsAsyncCallbackResult(T)), apiData: ?*T) void)) (fn (result: ?*T, error_code: c_int, user_data: ?*c_void) callconv(.C) void) { + const inner = struct { + fn innerfunc(result: ?*T, error_code: c_int, user_data: ?*c_void) callconv(.C) void { + httplog.debug(message, .{}); + var asyncResult = userDataTo(AsyncResult(AwsAsyncCallbackResult(T)), user_data); + + asyncResult.sync.store(true, .SeqCst); + + asyncResult.count += 1; + asyncResult.result.error_code = error_code; + + resultAssignment(asyncResult, result); + // asyncResult.result.result = result; + + asyncResult.sync.store(false, .SeqCst); + } + }; + return inner.innerfunc; + } + + fn assignCredentialsOnCallback(asyncResult: *AsyncResult(AwsAsyncCallbackResult(c.aws_credentials)), credentials: ?*c.aws_credentials) void { + if (asyncResult.result.result) |result| { + c.aws_credentials_release(result); + } + + asyncResult.result.result = credentials; + + if (credentials) |cred| { + c.aws_credentials_acquire(cred); + } + } +}; + +fn cInit(allocator: *std.mem.Allocator) void { + // TODO: what happens if we actually get an allocator? + log.debug("auth init", .{}); + c_allocator = c.aws_default_allocator(); + // TODO: Grab logging level from environment + // See levels here: + // https://github.com/awslabs/aws-c-common/blob/ce964ca459759e685547e8aa95cada50fd078eeb/include/aws/common/logging.h#L13-L19 + // We set this to FATAL mostly because we're handling errors for the most + // part here in zig-land. We would therefore set up for something like + // AWS_LL_WARN, but the auth library is bubbling up an AWS_LL_ERROR + // level message about not being able to open an aws config file. 
This + // could be an error, but we don't need to panic people if configuration + // is done via environment variables + var logger_options = c.aws_logger_standard_options{ + // .level = .AWS_LL_WARN, + // .level = .AWS_LL_INFO, + // .level = .AWS_LL_DEBUG, + // .level = .AWS_LL_TRACE, + .level = .AWS_LL_FATAL, + .file = c.get_std_err(), + .filename = null, + }; + const rc = c.aws_logger_init_standard(&c_logger, c_allocator, &logger_options); + if (rc != c.AWS_OP_SUCCESS) { + std.debug.panic("Could not configure logging: {s}", .{c.aws_error_debug_str(c.aws_last_error())}); + } + + c.aws_logger_set(&c_logger); + // auth could use http library, so we'll init http, then auth + // TODO: determine deallocation of ca_path + c.aws_http_library_init(c_allocator); + c.aws_auth_library_init(c_allocator); +} + +fn cDeinit() void { // probably the wrong name + if (Aws.tls_ctx) |ctx| { + httplog.debug("tls_ctx deinit start", .{}); + c.aws_tls_ctx_release(ctx); + httplog.debug("tls_ctx deinit end", .{}); + } + if (Aws.tls_ctx_options) |opts| { + // See: + // https://github.com/awslabs/aws-c-io/blob/6c7bae503961545c5e99c6c836c4b37749cfc4ad/source/tls_channel_handler.c#L25 + // + // The way this structure is constructed (setupTls/makeRequest), the only + // thing we need to clean up here is the alpn_list, which is set by + // aws_tls_ctx_options_set_alpn_list to a constant value. My guess here + // is that memory is not allocated - the pointer is looking at the program data. + // So the pointer is non-zero, but cannot be deallocated, and we segfault + httplog.debug("tls_ctx_options deinit unnecessary - skipping", .{}); + // log.debug("tls_ctx_options deinit start. 
alpn_list: {*}", .{opts.alpn_list}); + // c.aws_string_destroy(opts.alpn_list); + // c.aws_tls_ctx_options_clean_up(opts); + // log.debug("tls_ctx_options deinit end", .{}); + } + c.aws_http_library_clean_up(); + log.debug("auth clean up start", .{}); + c.aws_auth_library_clean_up(); + log.debug("auth clean up complete", .{}); +} + +pub const ResponseMetadata = struct { + request_id: ?[]const u8, + allocator: *std.mem.Allocator, + pub fn deinit(self: *const ResponseMetadata) void { + if (self.request_id) |id| { + self.allocator.free(id); + } + } +}; + +pub const Options = struct { + region: []const u8 = "aws-global", + dualstack: bool = false, +}; + +pub const SigningOptions = struct { + region: []const u8 = "aws-global", + service: []const u8, +}; + +const EndPoint = struct { + uri: []const u8, + host: []const u8, + scheme: []const u8, + port: u16, + allocator: *std.mem.Allocator, + + fn deinit(self: EndPoint) void { + self.allocator.free(self.uri); + } +}; + +pub fn metadataFromResponse(allocator: *std.mem.Allocator, responseXml: []const u8) !ResponseMetadata { + const doc = try xml.parse(allocator, responseXml); + defer doc.deinit(); + const meta = doc.root.findChildByTag("ResponseMetadata"); + const request_id_src = meta.?.getCharData("RequestId"); + // requestIdSrc will be deallocated when deinit is called + // so we need to copy it locally + const request_id = if (request_id_src) |id| + try std.mem.dupe(allocator, u8, id) + else + null; + + return ResponseMetadata{ + .request_id = request_id, + .allocator = allocator, + }; +} + +fn regionSubDomain(allocator: *std.mem.Allocator, service: []const u8, region: []const u8, useDualStack: bool) !EndPoint { + const environment_override = std.os.getenv("AWS_ENDPOINT_URL"); + if (environment_override) |override| { + const uri = try std.fmt.allocPrint(allocator, "{s}", .{override}); + return endPointFromUri(allocator, uri); + } + // Fallback to us-east-1 if global endpoint does not exist. 
+ const realregion = if (std.mem.eql(u8, region, "aws-global")) "us-east-1" else region; + const dualstack = if (useDualStack) ".dualstack" else ""; + + const domain = switch (std.hash_map.hashString(region)) { + US_ISO_EAST_1_HASH => "c2s.ic.gov", + CN_NORTH_1_HASH, CN_NORTHWEST_1_HASH => "amazonaws.com.cn", + US_ISOB_EAST_1_HASH => "sc2s.sgov.gov", + else => "amazonaws.com", + }; + + const uri = try std.fmt.allocPrint(allocator, "https://{s}{s}.{s}.{s}", .{ service, dualstack, realregion, domain }); + const host = uri["https://".len..]; + log.debug("host: {s}, scheme: {s}, port: {}", .{ host, "https", 443 }); + return EndPoint{ + .uri = uri, + .host = host, + .scheme = "https", + .port = 443, + .allocator = allocator, + }; +} + +/// creates an endpoint from a uri string. +/// +/// allocator: Will be used only to construct the EndPoint struct +/// uri: string constructed in such a way that deallocation is needed +fn endPointFromUri(allocator: *std.mem.Allocator, uri: []const u8) !EndPoint { + var scheme: []const u8 = ""; + var host: []const u8 = ""; + var port: u16 = 443; + var host_start: usize = 0; + var host_end: usize = 0; + for (uri) |ch, i| { + switch (ch) { + ':' => { + if (!std.mem.eql(u8, scheme, "")) { + // here to end is port - this is likely a bug if ipv6 address used + const rest_of_uri = uri[i + 1 ..]; + port = try std.fmt.parseUnsigned(u16, rest_of_uri, 10); + host_end = i; + } + }, + '/' => { + if (host_start == 0) { + host_start = i + 2; + scheme = uri[0 .. 
i - 1]; + if (std.mem.eql(u8, scheme, "http")) { + port = 80; + } else { + port = 443; + } + } + }, + else => continue, + } + } + if (host_end == 0) { + host_end = uri.len; + } + host = uri[host_start..host_end]; + + log.debug("host: {s}, scheme: {s}, port: {}", .{ host, scheme, port }); + return EndPoint{ + .uri = uri, + .host = host, + .scheme = scheme, + .allocator = allocator, + .port = port, + }; +} + +const Header = struct { + name: []const u8, + value: []const u8, +}; +const RequestContext = struct { + connection: ?*c.aws_http_connection = null, + connection_complete: std_atomic_bool.Bool = std_atomic_bool.Bool.init(false), // This is a 0.8.0 feature... :( + request_complete: std_atomic_bool.Bool = std_atomic_bool.Bool.init(false), // This is a 0.8.0 feature... :( + return_error: ?Aws.AwsError = null, + allocator: *std.mem.Allocator, + body: ?[]const u8 = null, + response_code: ?u16 = null, + headers: ?std.ArrayList(Header) = null, + + const Self = @This(); + + pub fn deinit(self: Self) void { + self.allocator.free(self.body); + if (self.headers) |hs| { + for (hs) |h| { + // deallocate the copied values + self.allocator.free(h.name); + self.allocator.free(h.value); + } + // deallocate the structure itself + h.deinit(); + } + } + + pub fn appendToBody(self: *Self, fragment: []const u8) !void { + var orig_body: []const u8 = ""; + if (self.body) |b| { + orig_body = try self.allocator.dupeZ(u8, b); + self.allocator.free(self.body.?); + self.body = null; + } + defer self.allocator.free(orig_body); + self.body = try std.fmt.allocPrint(self.allocator, "{s}{s}", .{ orig_body, fragment }); + } + + pub fn addHeader(self: *Self, name: []const u8, value: []const u8) !void { + if (self.headers == null) + self.headers = std.ArrayList(Header).init(self.allocator); + + const name_copy = try self.allocator.dupeZ(u8, name); + const value_copy = try self.allocator.dupeZ(u8, value); + + try self.headers.?.append(.{ + .name = name_copy, + .value = value_copy, + }); + } +}; diff 
--git a/src/bitfield-workaround.c b/src/bitfield-workaround.c new file mode 100644 index 0000000..bdb6916 --- /dev/null +++ b/src/bitfield-workaround.c @@ -0,0 +1,34 @@ +#include +#include + +#include "bitfield-workaround.h" + +extern void *new_aws_signing_config( + struct aws_allocator *allocator, + const struct bitfield_workaround_aws_signing_config_aws *config) { + struct aws_signing_config_aws *new_config = aws_mem_acquire(allocator, sizeof(struct aws_signing_config_aws)); + + new_config->algorithm = config->algorithm; + new_config->config_type = config->config_type; + new_config->signature_type = config->signature_type; + new_config->region = config->region; + new_config->service = config->service; + new_config->should_sign_header = config->should_sign_header; + new_config->should_sign_header_ud = config->should_sign_header_ud; + new_config->flags.use_double_uri_encode = config->flags.use_double_uri_encode; + new_config->flags.should_normalize_uri_path = config->flags.should_normalize_uri_path; + new_config->flags.omit_session_token = config->flags.omit_session_token; + new_config->signed_body_value = config->signed_body_value; + new_config->signed_body_header = config->signed_body_header; + new_config->credentials = config->credentials; + new_config->credentials_provider = config->credentials_provider; + new_config->expiration_in_seconds = config->expiration_in_seconds; + + aws_date_time_init_now(&new_config->date); + + return new_config; +} + +extern FILE *get_std_err() { + return stderr; +} diff --git a/src/bitfield-workaround.h b/src/bitfield-workaround.h new file mode 100644 index 0000000..e2ca13d --- /dev/null +++ b/src/bitfield-workaround.h @@ -0,0 +1,142 @@ +#ifndef ZIG_AWS_BITFIELD_WORKAROUND_H +#define ZIG_AWS_BITFIELD_WORKAROUND_H + +#include +#include + + + +// Copied verbatim from https://github.com/awslabs/aws-c-auth/blob/main/include/aws/auth/signing_config.h#L127-L241 +// However, the flags has changed to uint32_t without bitfield annotations 
+// as Zig does not support them yet. See https://github.com/ziglang/zig/issues/1499 +// We've renamed as well to make clear what's going on +// +// Signing date is also somewhat problematic, so we removed it and it is +// part of the c code + +/* + * Put all flags in here at the end. If this grows, stay aware of bit-space overflow and ABI compatibilty. + */ +struct bitfield_workaround_aws_signing_config_aws_flags { + /** + * We assume the uri will be encoded once in preparation for transmission. Certain services + * do not decode before checking signature, requiring us to actually double-encode the uri in the canonical + * request in order to pass a signature check. + */ + uint32_t use_double_uri_encode; + + /** + * Controls whether or not the uri paths should be normalized when building the canonical request + */ + uint32_t should_normalize_uri_path; + + /** + * Controls whether "X-Amz-Security-Token" is omitted from the canonical request. + * "X-Amz-Security-Token" is added during signing, as a header or + * query param, when credentials have a session token. + * If false (the default), this parameter is included in the canonical request. + * If true, this parameter is still added, but omitted from the canonical request. + */ + uint32_t omit_session_token; +}; + +/** + * A configuration structure for use in AWS-related signing. Currently covers sigv4 only, but is not required to. + */ +struct bitfield_workaround_aws_signing_config_aws { + + /** + * What kind of config structure is this? + */ + enum aws_signing_config_type config_type; + + /** + * What signing algorithm to use. + */ + enum aws_signing_algorithm algorithm; + + /** + * What sort of signature should be computed? + */ + enum aws_signature_type signature_type; + + /** + * The region to sign against + */ + struct aws_byte_cursor region; + + /** + * name of service to sign a request for + */ + struct aws_byte_cursor service; + + /** + * Raw date to use during the signing process. 
+ */ + // struct aws_date_time date; + + /** + * Optional function to control which headers are a part of the canonical request. + * Skipping auth-required headers will result in an unusable signature. Headers injected by the signing process + * are not skippable. + * + * This function does not override the internal check function (x-amzn-trace-id, user-agent), but rather + * supplements it. In particular, a header will get signed if and only if it returns true to both + * the internal check (skips x-amzn-trace-id, user-agent) and this function (if defined). + */ + aws_should_sign_header_fn *should_sign_header; + void *should_sign_header_ud; + + /* + * Put all flags in here at the end. If this grows, stay aware of bit-space overflow and ABI compatibilty. + */ + struct bitfield_workaround_aws_signing_config_aws_flags flags; + + /** + * Optional string to use as the canonical request's body value. + * If string is empty, a value will be calculated from the payload during signing. + * Typically, this is the SHA-256 of the (request/chunk/event) payload, written as lowercase hex. + * If this has been precalculated, it can be set here. Special values used by certain services can also be set + * (e.g. "UNSIGNED-PAYLOAD" "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" "STREAMING-AWS4-HMAC-SHA256-EVENTS"). + */ + struct aws_byte_cursor signed_body_value; + + /** + * Controls what body "hash" header, if any, should be added to the canonical request and the signed request: + * AWS_SBHT_NONE - no header should be added + * AWS_SBHT_X_AMZ_CONTENT_SHA256 - the body "hash" should be added in the X-Amz-Content-Sha256 header + */ + enum aws_signed_body_header_type signed_body_header; + + /* + * Signing key control: + * + * (1) If "credentials" is valid, use it + * (2) Else if "credentials_provider" is valid, query credentials from the provider and use the result + * (3) Else fail + * + */ + + /** + * AWS Credentials to sign with. 
+ */ + const struct aws_credentials *credentials; + + /** + * AWS credentials provider to fetch credentials from. + */ + struct aws_credentials_provider *credentials_provider; + + /** + * If non-zero and the signing transform is query param, then signing will add X-Amz-Expires to the query + * string, equal to the value specified here. If this value is zero or if header signing is being used then + * this parameter has no effect. + */ + uint64_t expiration_in_seconds; +}; + + + +extern void *new_aws_signing_config(struct aws_allocator *allocator, const struct bitfield_workaround_aws_signing_config_aws *config); +extern FILE *get_std_err(); +#endif diff --git a/src/bool.zig b/src/bool.zig new file mode 100644 index 0000000..c968b86 --- /dev/null +++ b/src/bool.zig @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2015-2021 Zig Contributors +// This file is part of [zig](https://ziglang.org/), which is MIT licensed. +// The MIT license requires this copyright notice to be included in all copies +// and substantial portions of the software. + +const std = @import("std"); +const builtin = std.builtin; +const testing = std.testing; + +/// Thread-safe, lock-free boolean +pub const Bool = extern struct { + unprotected_value: bool, + + pub const Self = @This(); + + pub fn init(init_val: bool) Self { + return Self{ .unprotected_value = init_val }; + } + + // xchg is only valid rmw operation for a bool + /// Atomically modifies memory and then returns the previous value. 
+ pub fn xchg(self: *Self, operand: bool, comptime ordering: std.builtin.AtomicOrder) bool { + switch (ordering) { + .Monotonic, .Acquire, .Release, .AcqRel, .SeqCst => {}, + else => @compileError("Invalid ordering '" ++ @tagName(ordering) ++ "' for a RMW operation"), + } + return @atomicRmw(bool, &self.unprotected_value, .Xchg, operand, ordering); + } + + pub fn load(self: *Self, comptime ordering: std.builtin.AtomicOrder) bool { + switch (ordering) { + .Unordered, .Monotonic, .Acquire, .SeqCst => {}, + else => @compileError("Invalid ordering '" ++ @tagName(ordering) ++ "' for a load operation"), + } + return @atomicLoad(bool, &self.unprotected_value, ordering); + } + + pub fn store(self: *Self, value: bool, comptime ordering: std.builtin.AtomicOrder) void { + switch (ordering) { + .Unordered, .Monotonic, .Release, .SeqCst => {}, + else => @compileError("Invalid ordering '" ++ @tagName(ordering) ++ "' for a store operation"), + } + @atomicStore(bool, &self.unprotected_value, value, ordering); + } +}; + +test "std.atomic.Bool" { + var a = Bool.init(false); + testing.expectEqual(false, a.xchg(false, .SeqCst)); + testing.expectEqual(false, a.load(.SeqCst)); + a.store(true, .SeqCst); + testing.expectEqual(true, a.xchg(false, .SeqCst)); + testing.expectEqual(false, a.load(.SeqCst)); +} diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..e95d29e --- /dev/null +++ b/src/main.zig @@ -0,0 +1,63 @@ +const std = @import("std"); +const aws = @import("aws.zig"); + +pub fn log( + comptime level: std.log.Level, + comptime scope: @TypeOf(.EnumLiteral), + comptime format: []const u8, + args: anytype, +) void { + // Ignore awshttp messages + if (scope == .awshttp and @enumToInt(level) >= @enumToInt(std.log.Level.debug)) + return; + + const scope_prefix = "(" ++ @tagName(scope) ++ "): "; + const prefix = "[" ++ @tagName(level) ++ "] " ++ scope_prefix; + + // Print the message to stderr, silently ignoring any errors + const held = 
std.debug.getStderrMutex().acquire(); + defer held.release(); + const stderr = std.io.getStdErr().writer(); + nosuspend stderr.print(prefix ++ format ++ "\n", args) catch return; +} + +pub fn main() anyerror!void { + // Uncomment if you want to log allocations + // const file = try std.fs.cwd().createFile("/tmp/allocations.log", .{ .truncate = true }); + // defer file.close(); + // var child_allocator = std.heap.c_allocator; + // const allocator = &std.heap.loggingAllocator(child_allocator, file.writer()).allocator; + const allocator = std.heap.c_allocator; + + const options = aws.Options{ + .region = "us-west-2", + }; + std.log.info("Start", .{}); + + var client = aws.Aws.init(allocator); + defer client.deinit(); + const resp = try client.call(aws.services.sts.get_caller_identity.Request{}, options); + // TODO: This is a bit wonky. Root cause is lack of declarations in + // comptime-generated types + defer aws.Aws.responseDeinit(resp.raw_response, resp.response_metadata); + + // Flip to true to run a second time. This will help debug + // allocation/deallocation issues + const test_twice = false; + if (test_twice) { + std.time.sleep(1000 * std.time.ns_per_ms); + std.log.info("second request", .{}); + + var client2 = aws.Aws.init(allocator); + defer client2.deinit(); + const resp2 = try client2.call(aws.services.sts.get_caller_identity.Request{}, options); // catch here and try alloc? 
+ defer aws.Aws.responseDeinit(resp2.raw_response, resp2.response_metadata); + } + + std.log.info("arn: {s}", .{resp.arn}); + std.log.info("id: {s}", .{resp.user_id}); + std.log.info("account: {s}", .{resp.account}); + std.log.info("requestId: {s}", .{resp.response_metadata.request_id}); + + std.log.info("Departing main", .{}); +} diff --git a/src/xml.zig b/src/xml.zig new file mode 100644 index 0000000..3fa9d16 --- /dev/null +++ b/src/xml.zig @@ -0,0 +1,649 @@ +const std = @import("std"); +const mem = std.mem; +const testing = std.testing; +const Allocator = mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; +const ArrayList = std.ArrayList; + +pub const Attribute = struct { + name: []const u8, value: []const u8 +}; + +pub const Content = union(enum) { + CharData: []const u8, Comment: []const u8, Element: *Element +}; + +pub const Element = struct { + pub const AttributeList = ArrayList(*Attribute); + pub const ContentList = ArrayList(Content); + + tag: []const u8, + attributes: AttributeList, + children: ContentList, + + fn init(tag: []const u8, alloc: *Allocator) Element { + return .{ + .tag = tag, + .attributes = AttributeList.init(alloc), + .children = ContentList.init(alloc), + }; + } + + pub fn getAttribute(self: *Element, attrib_name: []const u8) ?[]const u8 { + for (self.attributes.items) |child| { + if (mem.eql(u8, child.name, attrib_name)) { + return child.value; + } + } + + return null; + } + + pub fn getCharData(self: *Element, child_tag: []const u8) ?[]const u8 { + const child = self.findChildByTag(child_tag) orelse return null; + if (child.children.items.len != 1) { + return null; + } + + return switch (child.children.items[0]) { + .CharData => |char_data| char_data, + else => null, + }; + } + + pub fn iterator(self: *Element) ChildIterator { + return .{ + .items = self.children.items, + .i = 0, + }; + } + + pub fn elements(self: *Element) ChildElementIterator { + return .{ + .inner = self.iterator(), + }; + } + + pub fn 
findChildByTag(self: *Element, tag: []const u8) ?*Element { + return self.findChildrenByTag(tag).next(); + } + + pub fn findChildrenByTag(self: *Element, tag: []const u8) FindChildrenByTagIterator { + return .{ + .inner = self.elements(), + .tag = tag, + }; + } + + pub const ChildIterator = struct { + items: []Content, + i: usize, + + pub fn next(self: *ChildIterator) ?*Content { + if (self.i < self.items.len) { + self.i += 1; + return &self.items[self.i - 1]; + } + + return null; + } + }; + + pub const ChildElementIterator = struct { + inner: ChildIterator, + + pub fn next(self: *ChildElementIterator) ?*Element { + while (self.inner.next()) |child| { + if (child.* != .Element) { + continue; + } + + return child.*.Element; + } + + return null; + } + }; + + pub const FindChildrenByTagIterator = struct { + inner: ChildElementIterator, + tag: []const u8, + + pub fn next(self: *FindChildrenByTagIterator) ?*Element { + while (self.inner.next()) |child| { + if (!mem.eql(u8, child.tag, self.tag)) { + continue; + } + + return child; + } + + return null; + } + }; +}; + +pub const XmlDecl = struct { + version: []const u8, encoding: ?[]const u8, standalone: ?bool +}; + +pub const Document = struct { + arena: ArenaAllocator, + xml_decl: ?*XmlDecl, + root: *Element, + + pub fn deinit(self: Document) void { + var arena = self.arena; // Copy to stack so self can be taken by value. 
+ arena.deinit(); + } +}; + +const ParseContext = struct { + source: []const u8, + offset: usize, + line: usize, + column: usize, + + fn init(source: []const u8) ParseContext { + return .{ + .source = source, + .offset = 0, + .line = 0, + .column = 0, + }; + } + + fn peek(self: *ParseContext) ?u8 { + return if (self.offset < self.source.len) self.source[self.offset] else null; + } + + fn consume(self: *ParseContext) !u8 { + if (self.offset < self.source.len) { + return self.consumeNoEof(); + } + + return error.UnexpectedEof; + } + + fn consumeNoEof(self: *ParseContext) u8 { + std.debug.assert(self.offset < self.source.len); + const c = self.source[self.offset]; + self.offset += 1; + + if (c == '\n') { + self.line += 1; + self.column = 0; + } else { + self.column += 1; + } + + return c; + } + + fn eat(self: *ParseContext, char: u8) bool { + self.expect(char) catch return false; + return true; + } + + fn expect(self: *ParseContext, expected: u8) !void { + if (self.peek()) |actual| { + if (expected != actual) { + return error.UnexpectedCharacter; + } + + _ = self.consumeNoEof(); + return; + } + + return error.UnexpectedEof; + } + + fn eatStr(self: *ParseContext, text: []const u8) bool { + self.expectStr(text) catch return false; + return true; + } + + fn expectStr(self: *ParseContext, text: []const u8) !void { + if (self.source.len < self.offset + text.len) { + return error.UnexpectedEof; + } else if (std.mem.startsWith(u8, self.source[self.offset..], text)) { + var i: usize = 0; + while (i < text.len) : (i += 1) { + _ = self.consumeNoEof(); + } + + return; + } + + return error.UnexpectedCharacter; + } + + fn eatWs(self: *ParseContext) bool { + var ws = false; + + while (self.peek()) |ch| { + switch (ch) { + ' ', '\t', '\n', '\r' => { + ws = true; + _ = self.consumeNoEof(); + }, + else => break, + } + } + + return ws; + } + + fn expectWs(self: *ParseContext) !void { + if (!self.eatWs()) return error.UnexpectedCharacter; + } + + fn currentLine(self: ParseContext) 
[]const u8 { + var begin: usize = 0; + if (mem.lastIndexOfScalar(u8, self.source[0..self.offset], '\n')) |prev_nl| { + begin = prev_nl + 1; + } + + var end = mem.indexOfScalarPos(u8, self.source, self.offset, '\n') orelse self.source.len; + return self.source[begin..end]; + } +}; + +test "ParseContext" { + { + var ctx = ParseContext.init("I like pythons"); + testing.expectEqual(@as(?u8, 'I'), ctx.peek()); + testing.expectEqual(@as(u8, 'I'), ctx.consumeNoEof()); + testing.expectEqual(@as(?u8, ' '), ctx.peek()); + testing.expectEqual(@as(u8, ' '), try ctx.consume()); + + testing.expect(ctx.eat('l')); + testing.expectEqual(@as(?u8, 'i'), ctx.peek()); + testing.expectEqual(false, ctx.eat('a')); + testing.expectEqual(@as(?u8, 'i'), ctx.peek()); + + try ctx.expect('i'); + testing.expectEqual(@as(?u8, 'k'), ctx.peek()); + testing.expectError(error.UnexpectedCharacter, ctx.expect('a')); + testing.expectEqual(@as(?u8, 'k'), ctx.peek()); + + testing.expect(ctx.eatStr("ke")); + testing.expectEqual(@as(?u8, ' '), ctx.peek()); + + testing.expect(ctx.eatWs()); + testing.expectEqual(@as(?u8, 'p'), ctx.peek()); + testing.expectEqual(false, ctx.eatWs()); + testing.expectEqual(@as(?u8, 'p'), ctx.peek()); + + testing.expectEqual(false, ctx.eatStr("aaaaaaaaa")); + testing.expectEqual(@as(?u8, 'p'), ctx.peek()); + + testing.expectError(error.UnexpectedEof, ctx.expectStr("aaaaaaaaa")); + testing.expectEqual(@as(?u8, 'p'), ctx.peek()); + testing.expectError(error.UnexpectedCharacter, ctx.expectStr("pytn")); + testing.expectEqual(@as(?u8, 'p'), ctx.peek()); + try ctx.expectStr("python"); + testing.expectEqual(@as(?u8, 's'), ctx.peek()); + } + + { + var ctx = ParseContext.init(""); + testing.expectEqual(ctx.peek(), null); + testing.expectError(error.UnexpectedEof, ctx.consume()); + testing.expectEqual(ctx.eat('p'), false); + testing.expectError(error.UnexpectedEof, ctx.expect('p')); + } +} + +pub const ParseError = error{ IllegalCharacter, UnexpectedEof, UnexpectedCharacter, UnclosedValue, 
UnclosedComment, InvalidName, InvalidEntity, InvalidStandaloneValue, NonMatchingClosingTag, InvalidDocument, OutOfMemory }; + +pub fn parse(backing_allocator: *Allocator, source: []const u8) !Document { + var ctx = ParseContext.init(source); + return try parseDocument(&ctx, backing_allocator); +} + +fn parseDocument(ctx: *ParseContext, backing_allocator: *Allocator) !Document { + var doc = Document{ + .arena = ArenaAllocator.init(backing_allocator), + .xml_decl = null, + .root = undefined, + }; + + errdefer doc.deinit(); + + try trySkipComments(ctx, &doc.arena.allocator); + + doc.xml_decl = try tryParseProlog(ctx, &doc.arena.allocator); + _ = ctx.eatWs(); + try trySkipComments(ctx, &doc.arena.allocator); + + doc.root = (try tryParseElement(ctx, &doc.arena.allocator)) orelse return error.InvalidDocument; + _ = ctx.eatWs(); + try trySkipComments(ctx, &doc.arena.allocator); + + if (ctx.peek() != null) return error.InvalidDocument; + + return doc; +} + +fn parseAttrValue(ctx: *ParseContext, alloc: *Allocator) ![]const u8 { + const quote = try ctx.consume(); + if (quote != '"' and quote != '\'') return error.UnexpectedCharacter; + + const begin = ctx.offset; + + while (true) { + const c = ctx.consume() catch return error.UnclosedValue; + if (c == quote) break; + } + + const end = ctx.offset - 1; + + return try dupeAndUnescape(alloc, ctx.source[begin..end]); +} + +fn parseEqAttrValue(ctx: *ParseContext, alloc: *Allocator) ![]const u8 { + _ = ctx.eatWs(); + try ctx.expect('='); + _ = ctx.eatWs(); + + return try parseAttrValue(ctx, alloc); +} + +fn parseNameNoDupe(ctx: *ParseContext) ![]const u8 { + // XML's spec on names is very long, so to make this easier + // we just take any character that is not special and not whitespace + const begin = ctx.offset; + + while (ctx.peek()) |ch| { + switch (ch) { + ' ', '\t', '\n', '\r' => break, + '&', '"', '\'', '<', '>', '?', '=', '/' => break, + else => _ = ctx.consumeNoEof(), + } + } + + const end = ctx.offset; + if (begin == end) 
return error.InvalidName; + + return ctx.source[begin..end]; +} + +fn tryParseCharData(ctx: *ParseContext, alloc: *Allocator) !?[]const u8 { + const begin = ctx.offset; + + while (ctx.peek()) |ch| { + switch (ch) { + '<', '>' => break, + else => _ = ctx.consumeNoEof(), + } + } + + const end = ctx.offset; + if (begin == end) return null; + + return try dupeAndUnescape(alloc, ctx.source[begin..end]); +} + +fn parseContent(ctx: *ParseContext, alloc: *Allocator) ParseError!Content { + if (try tryParseCharData(ctx, alloc)) |cd| { + return Content{ .CharData = cd }; + } else if (try tryParseComment(ctx, alloc)) |comment| { + return Content{ .Comment = comment }; + } else if (try tryParseElement(ctx, alloc)) |elem| { + return Content{ .Element = elem }; + } else { + return error.UnexpectedCharacter; + } +} + +fn tryParseAttr(ctx: *ParseContext, alloc: *Allocator) !?*Attribute { + const name = parseNameNoDupe(ctx) catch return null; + _ = ctx.eatWs(); + try ctx.expect('='); + _ = ctx.eatWs(); + const value = try parseAttrValue(ctx, alloc); + + const attr = try alloc.create(Attribute); + attr.name = try mem.dupe(alloc, u8, name); + attr.value = value; + return attr; +} + +fn tryParseElement(ctx: *ParseContext, alloc: *Allocator) !?*Element { + const start = ctx.offset; + if (!ctx.eat('<')) return null; + const tag = parseNameNoDupe(ctx) catch { + ctx.offset = start; + return null; + }; + + const element = try alloc.create(Element); + element.* = Element.init(try std.mem.dupe(alloc, u8, tag), alloc); + + while (ctx.eatWs()) { + const attr = (try tryParseAttr(ctx, alloc)) orelse break; + try element.attributes.append(attr); + } + + if (ctx.eatStr("/>")) { + return element; + } + + try ctx.expect('>'); + + while (true) { + if (ctx.peek() == null) { + return error.UnexpectedEof; + } else if (ctx.eatStr("'); + return element; +} + +test "tryParseElement" { + var arena = std.heap.ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + var alloc = &arena.allocator; + + { 
+ var ctx = ParseContext.init("<= a='b'/>"); + testing.expectEqual(@as(?*Element, null), try tryParseElement(&ctx, alloc)); + testing.expectEqual(@as(?u8, '<'), ctx.peek()); + } + + { + var ctx = ParseContext.init(""); + const elem = try tryParseElement(&ctx, alloc); + testing.expectEqualSlices(u8, elem.?.tag, "python"); + + const size_attr = elem.?.attributes.items[0]; + testing.expectEqualSlices(u8, size_attr.name, "size"); + testing.expectEqualSlices(u8, size_attr.value, "15"); + + const color_attr = elem.?.attributes.items[1]; + testing.expectEqualSlices(u8, color_attr.name, "color"); + testing.expectEqualSlices(u8, color_attr.value, "green"); + } + + { + var ctx = ParseContext.init("test"); + const elem = try tryParseElement(&ctx, alloc); + testing.expectEqualSlices(u8, elem.?.tag, "python"); + testing.expectEqualSlices(u8, elem.?.children.items[0].CharData, "test"); + } + + { + var ctx = ParseContext.init("bdf"); + const elem = try tryParseElement(&ctx, alloc); + testing.expectEqualSlices(u8, elem.?.tag, "a"); + testing.expectEqualSlices(u8, elem.?.children.items[0].CharData, "b"); + testing.expectEqualSlices(u8, elem.?.children.items[1].Element.tag, "c"); + testing.expectEqualSlices(u8, elem.?.children.items[2].CharData, "d"); + testing.expectEqualSlices(u8, elem.?.children.items[3].Element.tag, "e"); + testing.expectEqualSlices(u8, elem.?.children.items[4].CharData, "f"); + testing.expectEqualSlices(u8, elem.?.children.items[5].Comment, "g"); + } +} + +fn tryParseProlog(ctx: *ParseContext, alloc: *Allocator) !?*XmlDecl { + const start = ctx.offset; + if (!ctx.eatStr(""); + return decl; +} + +test "tryParseProlog" { + var arena = std.heap.ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + var alloc = &arena.allocator; + + { + var ctx = ParseContext.init(""); + testing.expectEqual(@as(?*XmlDecl, null), try tryParseProlog(&ctx, alloc)); + testing.expectEqual(@as(?u8, '<'), ctx.peek()); + } + + { + var ctx = ParseContext.init(""); + const decl = 
try tryParseProlog(&ctx, alloc); + testing.expectEqualSlices(u8, "aa", decl.?.version); + testing.expectEqual(@as(?[]const u8, null), decl.?.encoding); + testing.expectEqual(@as(?bool, null), decl.?.standalone); + } + + { + var ctx = ParseContext.init(""); + const decl = try tryParseProlog(&ctx, alloc); + testing.expectEqualSlices(u8, "aa", decl.?.version); + testing.expectEqualSlices(u8, "bbb", decl.?.encoding.?); + testing.expectEqual(@as(?bool, true), decl.?.standalone.?); + } +} + +fn trySkipComments(ctx: *ParseContext, alloc: *Allocator) !void { + while (try tryParseComment(ctx, alloc)) |_| { + _ = ctx.eatWs(); + } +} + +fn tryParseComment(ctx: *ParseContext, alloc: *Allocator) !?[]const u8 { + if (!ctx.eatStr("")) { + _ = ctx.consume() catch return error.UnclosedComment; + } + + const end = ctx.offset - "-->".len; + return try mem.dupe(alloc, u8, ctx.source[begin..end]); +} + +fn unescapeEntity(text: []const u8) !u8 { + const EntitySubstition = struct { + text: []const u8, replacement: u8 + }; + + const entities = [_]EntitySubstition{ + .{ .text = "<", .replacement = '<' }, + .{ .text = ">", .replacement = '>' }, + .{ .text = "&", .replacement = '&' }, + .{ .text = "'", .replacement = '\'' }, + .{ .text = """, .replacement = '"' }, + }; + + for (entities) |entity| { + if (std.mem.eql(u8, text, entity.text)) return entity.replacement; + } + + return error.InvalidEntity; +} + +fn dupeAndUnescape(alloc: *Allocator, text: []const u8) ![]const u8 { + const str = try alloc.alloc(u8, text.len); + + var j: usize = 0; + var i: usize = 0; + while (i < text.len) : (j += 1) { + if (text[i] == '&') { + const entity_end = 1 + (mem.indexOfScalarPos(u8, text, i, ';') orelse return error.InvalidEntity); + str[j] = try unescapeEntity(text[i..entity_end]); + i = entity_end; + } else { + str[j] = text[i]; + i += 1; + } + } + + return alloc.shrink(str, j); +} + +test "dupeAndUnescape" { + var arena = std.heap.ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + var 
alloc = &arena.allocator; + + testing.expectEqualSlices(u8, "test", try dupeAndUnescape(alloc, "test")); + testing.expectEqualSlices(u8, "ad\"e'f<", try dupeAndUnescape(alloc, "a<b&c>d"e'f<")); + testing.expectError(error.InvalidEntity, dupeAndUnescape(alloc, "python&")); + testing.expectError(error.InvalidEntity, dupeAndUnescape(alloc, "python&&")); + testing.expectError(error.InvalidEntity, dupeAndUnescape(alloc, "python&test;")); + testing.expectError(error.InvalidEntity, dupeAndUnescape(alloc, "python&boa")); +} + +test "Top level comments" { + var arena = std.heap.ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + var alloc = &arena.allocator; + + const doc = try parse(alloc, ""); + testing.expectEqualSlices(u8, "python", doc.root.tag); +}