diff --git a/README.md b/README.md index 5c9ca11..187df33 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,6 @@ Only environment variable based credentials can be used at the moment. TODO List: -* Add STS key support * Add option to cache signature keys * Implement credentials provider * Implement jitter/exponential backoff diff --git a/src/aws_credentials.zig b/src/aws_credentials.zig index 1a73910..c24b0bf 100644 --- a/src/aws_credentials.zig +++ b/src/aws_credentials.zig @@ -23,7 +23,8 @@ fn getEnvironmentCredentials(allocator: std.mem.Allocator) !?auth.Credentials { allocator, (try getEnvironmentVariable(allocator, "AWS_ACCESS_KEY_ID")) orelse return null, mutable_key, - try getEnvironmentVariable(allocator, "AWS_SESSION_TOKEN"), + (try getEnvironmentVariable(allocator, "AWS_SESSION_TOKEN")) orelse + try getEnvironmentVariable(allocator, "AWS_SECURITY_TOKEN"), ); } diff --git a/src/aws_signing.zig b/src/aws_signing.zig index 52628b4..c660da5 100644 --- a/src/aws_signing.zig +++ b/src/aws_signing.zig @@ -155,9 +155,18 @@ pub fn signRequest(allocator: std.mem.Allocator, request: base.Request, config: ); errdefer freeSignedRequest(allocator, &rc, config); - const newheaders = try allocator.alloc(base.Header, rc.headers.len + 2); + var additional_header_count: u2 = 2; + if (config.credentials.session_token != null) + additional_header_count += 1; + const newheaders = try allocator.alloc(base.Header, rc.headers.len + additional_header_count); errdefer allocator.free(newheaders); const oldheaders = rc.headers; + if (config.credentials.session_token) |t| { + newheaders[newheaders.len - 3] = base.Header{ + .name = "X-Amz-Security-Token", + .value = try allocator.dupe(u8, t), + }; + } errdefer freeSignedRequest(allocator, &rc, config); std.mem.copy(base.Header, newheaders, oldheaders); newheaders[newheaders.len - 2] = base.Header{ @@ -245,7 +254,10 @@ pub fn freeSignedRequest(allocator: std.mem.Allocator, request: *base.Request, c var remove_len: u2 = 0; for (request.headers) |h| { - if (std.ascii.eqlIgnoreCase(h.name, "X-Amz-Date") or std.ascii.eqlIgnoreCase(h.name, "Authorization")) { + if (std.ascii.eqlIgnoreCase(h.name, "X-Amz-Date") or + std.ascii.eqlIgnoreCase(h.name, "Authorization") or + std.ascii.eqlIgnoreCase(h.name, "X-Amz-Security-Token")) + { allocator.free(h.value); remove_len += 1; } @@ -332,7 +344,7 @@ fn createCanonicalRequest(allocator: std.mem.Allocator, request: base.Request, c const canonical_query = try canonicalQueryString(allocator, request.query); defer allocator.free(canonical_query); log.debug("canonical query: {s}", .{canonical_query}); - const canonical_headers = try canonicalHeaders(allocator, request.headers); + const canonical_headers = try canonicalHeaders(allocator, request.headers, config.flags); const payload_hash = try hash(allocator, request.body, config.signed_body_header); defer allocator.free(payload_hash); @@ -558,7 +570,7 @@ const CanonicalHeaders = struct { str: []const u8, signed_headers: []const u8, }; -fn canonicalHeaders(allocator: std.mem.Allocator, headers: []base.Header) !CanonicalHeaders { +fn canonicalHeaders(allocator: std.mem.Allocator, headers: []base.Header, flags: ConfigFlags) !CanonicalHeaders { // // Doc example. Original: // @@ -592,6 +604,17 @@ fn canonicalHeaders(allocator: std.mem.Allocator, headers: []base.Header) !Canon break; } } + // Well, this is fun (https://docs.aws.amazon.com/general/latest/gr/sigv4-add-signature-to-request.html): + // + // When you add the X-Amz-Security-Token parameter to the query string, + // some services require that you include this parameter in the + // canonical (signed) request. For other services, you add this + // parameter at the end, after you calculate the signature. For + // details, see the API reference documentation for that service. + if (flags.omit_session_token and std.ascii.eqlIgnoreCase(h.name, "X-Amz-Security-Token")) { + skip = true; + break; + } if (skip) continue; total_len += (h.name.len + h.value.len + 2); @@ -718,7 +741,7 @@ test "canonical headers" { \\x-amz-date:20150830T123600Z \\ ; - const actual = try canonicalHeaders(allocator, headers.items); + const actual = try canonicalHeaders(allocator, headers.items, .{}); defer allocator.free(actual.str); defer allocator.free(actual.signed_headers); try std.testing.expectEqualStrings(expected, actual.str);