diff --git a/src/aws_signing.zig b/src/aws_signing.zig index 4454a41..72c61ce 100644 --- a/src/aws_signing.zig +++ b/src/aws_signing.zig @@ -282,6 +282,136 @@ pub fn freeSignedRequest(allocator: std.mem.Allocator, request: *base.Request, c allocator.free(request.headers); } +pub const CredentialsFn = *const fn ([]const u8) ?auth.Credentials; +pub fn verify(allocator: std.mem.Allocator, request: std.http.Server.Request, request_body_reader: anytype, credentials_fn: CredentialsFn) !bool { + // Authorization: AWS4-HMAC-SHA256 Credential=ACCESS/20230908/us-west-2/s3/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class, Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523 + const auth_header = request.headers.getFirstValue("Authorization").?; + if (!std.mem.startsWith(u8, auth_header, "AWS4-HMAC-SHA256")) return error.UnsupportedAuthorizationType; + var credential: ?[]const u8 = null; + var signed_headers: ?[]const u8 = null; + var signature: ?[]const u8 = null; + var split_iterator = std.mem.splitSequence(u8, auth_header, " "); + while (split_iterator.next()) |auth_part| { + // NOTE: auth_part likely to end with , + if (std.ascii.startsWithIgnoreCase(auth_part, "Credential=")) { + credential = std.mem.trim(u8, auth_part["Credential=".len..], ","); + continue; + } + if (std.ascii.startsWithIgnoreCase(auth_part, "SignedHeaders=")) { + signed_headers = std.mem.trim(u8, auth_part["SignedHeaders=".len..], ","); + continue; + } + if (std.ascii.startsWithIgnoreCase(auth_part, "Signature=")) { + signature = std.mem.trim(u8, auth_part["Signature=".len..], ","); + continue; + } + } + if (credential == null) return error.AuthorizationHeaderMissingCredential; + if (signed_headers == null) return error.AuthorizationHeaderMissingSignedHeaders; + if (signature == null) return error.AuthorizationHeaderMissingSignature; + return verifyParsedAuthorization( + allocator, + request, + request_body_reader, + credential.?, + signed_headers.?, + signature.?, + credentials_fn, + ); +} + +fn verifyParsedAuthorization( + allocator: std.mem.Allocator, + request: std.http.Server.Request, + request_body_reader: anytype, + credential: []const u8, + signed_headers: []const u8, + signature: []const u8, + credentials_fn: CredentialsFn, +) !bool { + // AWS4-HMAC-SHA256 + // Credential=ACCESS/20230908/us-west-2/s3/aws4_request + // SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class + // Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523 + var credential_iterator = std.mem.split(u8, credential, "/"); + const access_key = credential_iterator.next().?; + const credentials = credentials_fn(access_key) orelse return error.CredentialsNotFound; + // TODO: https://stackoverflow.com/questions/29276609/aws-authentication-requires-a-valid-date-or-x-amz-date-header-curl + // For now I want to see this test pass + const normalized_iso_date = request.headers.getFirstValue("x-amz-date") orelse + request.headers.getFirstValue("Date").?; + log.debug("Got date: {s}", .{normalized_iso_date}); + _ = credential_iterator.next().?; // skip the date...I don't think we need this + const region = credential_iterator.next().?; + const service = credential_iterator.next().?; + const aws4_request = credential_iterator.next().?; + if (!std.mem.eql(u8, aws4_request, "aws4_request")) return error.UnexpectedCredentialValue; + var config = Config{ + .service = service, + .credentials = credentials, + .region = region, + .algorithm = .v4, + .signature_type = .headers, + .signed_body_header = .sha256, + .expiration_in_seconds = 0, + .signing_time = try date.dateTimeToTimestamp(try date.parseIso8601ToDateTime(normalized_iso_date)), + }; + + var headers = try allocator.alloc(base.Header, std.mem.count(u8, signed_headers, ";") + 1); + defer allocator.free(headers); + var signed_headers_iterator = std.mem.splitSequence(u8, signed_headers, ";"); + var inx: usize = 0; + while (signed_headers_iterator.next()) |signed_header| { + var is_forbidden = false; + inline for (forbidden_headers) |forbidden| { + if (std.ascii.eqlIgnoreCase(forbidden.name, signed_header)) { + is_forbidden = true; + break; + } + } + if (is_forbidden) continue; + headers[inx] = .{ + .name = signed_header, + .value = request.headers.getFirstValue(signed_header).?, + }; + inx += 1; + } + var target_iterator = std.mem.splitSequence(u8, request.target, "?"); + var signed_request = base.Request{ + .path = target_iterator.first(), + .headers = headers[0..inx], + .method = @tagName(request.method), + .content_type = request.headers.getFirstValue("content-type").?, + }; + signed_request.query = request.target[signed_request.path.len..]; // TODO: should this be +1? query here would include '?' + signed_request.body = try request_body_reader.readAllAlloc(allocator, std.math.maxInt(usize)); + defer allocator.free(signed_request.body); + signed_request = try signRequest(allocator, signed_request, config); + defer freeSignedRequest(allocator, &signed_request, config); + return verifySignedRequest(signed_request, signature); +} + +fn verifySignedRequest(signed_request: base.Request, signature: []const u8) !bool { + // We're not doing a lot of error checking here...we are all in control of this code + const auth_header = blk: { + for (signed_request.headers) |header| { + if (std.mem.eql(u8, header.name, "Authorization")) + break :blk header.value; + } + break :blk null; + }; + var split_iterator = std.mem.splitSequence(u8, auth_header.?, " "); + const calculated_signature = blk: { + while (split_iterator.next()) |auth_part| { + if (std.ascii.startsWithIgnoreCase(auth_part, "Signature=")) { + break :blk std.mem.trim(u8, auth_part["Signature=".len..], ","); + } + } + break :blk null; + }; + return std.mem.eql(u8, signature, calculated_signature.?); +} + fn getSigningKey(allocator: std.mem.Allocator, signing_date: []const u8, config: Config) ![]const u8 { // TODO: This is designed for lots of caching. We need to work that out // kSecret = your secret access key @@ -896,3 +1026,50 @@ test "can sign" { try std.testing.expectEqualStrings("Authorization", signed_req.headers[signed_req.headers.len - 1].name); try std.testing.expectEqualStrings(expected_auth, signed_req.headers[signed_req.headers.len - 1].value); } + +var test_credential: ?auth.Credentials = null; +test "can verify" { + const allocator = std.testing.allocator; + + const access_key = try allocator.dupe(u8, "ACCESS"); + const secret_key = try allocator.dupe(u8, "SECRET"); + test_credential = auth.Credentials.init(allocator, access_key, secret_key, null); + defer test_credential.?.deinit(); + + var headers = std.http.Headers.init(allocator); + defer headers.deinit(); + try headers.append("Connection", "keep-alive"); + try headers.append("Accept-Encoding", "gzip, deflate, zstd"); + try headers.append("TE", "gzip, deflate, trailers"); + try headers.append("Accept", "application/json"); + try headers.append("Host", "127.0.0.1"); + try headers.append("User-Agent", "zig-aws 1.0"); + try headers.append("Content-Type", "text/plain"); + try headers.append("x-amz-storage-class", "STANDARD"); + try headers.append("Content-Length", "3"); + try headers.append("X-Amz-Date", "20230908T170252Z"); + try headers.append("x-amz-content-sha256", "fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9"); + try headers.append("Authorization", "AWS4-HMAC-SHA256 Credential=ACCESS/20230908/us-west-2/s3/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class, Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523"); + + var buf = "bar".*; + var fis = std.io.fixedBufferStream(&buf); + var request = std.http.Server.Request{ + .method = std.http.Method.PUT, + .target = "/mysfitszj3t6webstack-hostingbucketa91a61fe-1ep3ezkgwpxr0/i/am/a/teapot/foo?x-id=PutObject", + .version = .@"HTTP/1.1", + .content_length = 3, + .headers = headers, + .parser = std.http.protocol.HeadersParser.initDynamic(std.math.maxInt(usize)), + }; + + // std.testing.log_level = .debug; + try std.testing.expect(try verify(allocator, request, fis.reader(), struct { + cred: auth.Credentials, + + const Self = @This(); + fn getCreds(access: []const u8) ?auth.Credentials { + if (std.mem.eql(u8, access, "ACCESS")) return test_credential.?; + return null; + } + }.getCreds)); +}