Compare commits

...

7 Commits

4 changed files with 217 additions and 15 deletions

View File

@ -15,6 +15,9 @@ pub const Options = struct {
dualstack: bool = false,
success_http_code: i64 = 200,
client: Client,
/// Used for testing to provide consistent signing. If null, will use current time
signing_time: ?i64 = null,
};
/// Using this constant may blow up build times. Recommed using Services()
@ -170,6 +173,7 @@ pub fn Request(comptime request_action: anytype) type {
.region = options.region,
.dualstack = options.dualstack,
.client = options.client,
.signing_time = options.signing_time,
});
}
@ -261,6 +265,7 @@ pub fn Request(comptime request_action: anytype) type {
.region = options.region,
.dualstack = options.dualstack,
.sigv4_service_name = Self.service_meta.sigv4_name,
.signing_time = options.signing_time,
},
);
defer response.deinit();
@ -1102,7 +1107,13 @@ pub fn IgnoringWriter(comptime WriterType: type) type {
};
}
fn reportTraffic(allocator: std.mem.Allocator, info: []const u8, request: awshttp.HttpRequest, response: awshttp.HttpResult, comptime reporter: fn (comptime []const u8, anytype) void) !void {
fn reportTraffic(
allocator: std.mem.Allocator,
info: []const u8,
request: awshttp.HttpRequest,
response: awshttp.HttpResult,
comptime reporter: fn (comptime []const u8, anytype) void,
) !void {
var msg = std.ArrayList(u8).init(allocator);
defer msg.deinit();
const writer = msg.writer();
@ -1489,6 +1500,8 @@ const TestSetup = struct {
const aws_creds = @import("aws_credentials.zig");
const aws_auth = @import("aws_authentication.zig");
const signing_time =
date.dateTimeToTimestamp(date.parseIso8601ToDateTime("20230908T170252Z") catch @compileError("Cannot parse date")) catch @compileError("Cannot parse date");
fn init(allocator: std.mem.Allocator, options: TestOptions) Self {
return .{
@ -1518,6 +1531,7 @@ const TestSetup = struct {
return .{
.region = "us-west-2",
.client = client,
.signing_time = signing_time,
};
}
@ -1965,7 +1979,9 @@ test "rest_xml_with_input: S3 put object" {
const s3opts = Options{
.region = "us-west-2",
.client = options.client,
.signing_time = TestSetup.signing_time,
};
// std.testing.log_level = .debug;
const result = try Request(services.s3.put_object).call(.{
.bucket = "mysfitszj3t6webstack-hostingbucketa91a61fe-1ep3ezkgwpxr0",
.key = "i/am/a/teapot/foo",
@ -1973,8 +1989,11 @@ test "rest_xml_with_input: S3 put object" {
.body = "bar",
.storage_class = "STANDARD",
}, s3opts);
std.log.info("PutObject Request id: {any}", .{result.response_metadata.request_id});
std.log.info("PutObject etag: {any}", .{result.response.e_tag.?});
for (test_harness.request_options.request_headers.list.items) |header| {
std.log.info("Request header: {s}: {s}", .{ header.name, header.value });
}
std.log.info("PutObject Request id: {s}", .{result.response_metadata.request_id});
std.log.info("PutObject etag: {s}", .{result.response.e_tag.?});
//mysfitszj3t6webstack-hostingbucketa91a61fe-1ep3ezkgwpxr0.s3.us-west-2.amazonaws.com
defer result.deinit();
test_harness.stop();

View File

@ -39,6 +39,9 @@ pub const Options = struct {
region: []const u8 = "aws-global",
dualstack: bool = false,
sigv4_service_name: ?[]const u8 = null,
/// Used for testing to provide consistent signing. If null, will use current time
signing_time: ?i64 = null,
};
pub const Header = base.Header;
@ -110,6 +113,7 @@ pub const AwsHttp = struct {
.region = getRegion(service, options.region),
.service = options.sigv4_service_name orelse service,
.credentials = creds,
.signing_time = options.signing_time,
};
return try self.makeRequest(endpoint, request, signing_config);
}

View File

@ -282,6 +282,136 @@ pub fn freeSignedRequest(allocator: std.mem.Allocator, request: *base.Request, c
allocator.free(request.headers);
}
pub const CredentialsFn = *const fn ([]const u8) ?auth.Credentials;
pub fn verify(allocator: std.mem.Allocator, request: std.http.Server.Request, request_body_reader: anytype, credentials_fn: CredentialsFn) !bool {
// Authorization: AWS4-HMAC-SHA256 Credential=ACCESS/20230908/us-west-2/s3/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class, Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523
const auth_header = request.headers.getFirstValue("Authorization").?;
if (!std.mem.startsWith(u8, auth_header, "AWS4-HMAC-SHA256")) return error.UnsupportedAuthorizationType;
var credential: ?[]const u8 = null;
var signed_headers: ?[]const u8 = null;
var signature: ?[]const u8 = null;
var split_iterator = std.mem.splitSequence(u8, auth_header, " ");
while (split_iterator.next()) |auth_part| {
// NOTE: auth_part likely to end with ,
if (std.ascii.startsWithIgnoreCase(auth_part, "Credential=")) {
credential = std.mem.trim(u8, auth_part["Credential=".len..], ",");
continue;
}
if (std.ascii.startsWithIgnoreCase(auth_part, "SignedHeaders=")) {
signed_headers = std.mem.trim(u8, auth_part["SignedHeaders=".len..], ",");
continue;
}
if (std.ascii.startsWithIgnoreCase(auth_part, "Signature=")) {
signature = std.mem.trim(u8, auth_part["Signature=".len..], ",");
continue;
}
}
if (credential == null) return error.AuthorizationHeaderMissingCredential;
if (signed_headers == null) return error.AuthorizationHeaderMissingSignedHeaders;
if (signature == null) return error.AuthorizationHeaderMissingSignature;
return verifyParsedAuthorization(
allocator,
request,
request_body_reader,
credential.?,
signed_headers.?,
signature.?,
credentials_fn,
);
}
fn verifyParsedAuthorization(
allocator: std.mem.Allocator,
request: std.http.Server.Request,
request_body_reader: anytype,
credential: []const u8,
signed_headers: []const u8,
signature: []const u8,
credentials_fn: CredentialsFn,
) !bool {
// AWS4-HMAC-SHA256
// Credential=ACCESS/20230908/us-west-2/s3/aws4_request
// SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class
// Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523
var credential_iterator = std.mem.split(u8, credential, "/");
const access_key = credential_iterator.next().?;
const credentials = credentials_fn(access_key) orelse return error.CredentialsNotFound;
// TODO: https://stackoverflow.com/questions/29276609/aws-authentication-requires-a-valid-date-or-x-amz-date-header-curl
// For now I want to see this test pass
const normalized_iso_date = request.headers.getFirstValue("x-amz-date") orelse
request.headers.getFirstValue("Date").?;
log.debug("Got date: {s}", .{normalized_iso_date});
_ = credential_iterator.next().?; // skip the date...I don't think we need this
const region = credential_iterator.next().?;
const service = credential_iterator.next().?;
const aws4_request = credential_iterator.next().?;
if (!std.mem.eql(u8, aws4_request, "aws4_request")) return error.UnexpectedCredentialValue;
var config = Config{
.service = service,
.credentials = credentials,
.region = region,
.algorithm = .v4,
.signature_type = .headers,
.signed_body_header = .sha256,
.expiration_in_seconds = 0,
.signing_time = try date.dateTimeToTimestamp(try date.parseIso8601ToDateTime(normalized_iso_date)),
};
var headers = try allocator.alloc(base.Header, std.mem.count(u8, signed_headers, ";") + 1);
defer allocator.free(headers);
var signed_headers_iterator = std.mem.splitSequence(u8, signed_headers, ";");
var inx: usize = 0;
while (signed_headers_iterator.next()) |signed_header| {
var is_forbidden = false;
inline for (forbidden_headers) |forbidden| {
if (std.ascii.eqlIgnoreCase(forbidden.name, signed_header)) {
is_forbidden = true;
break;
}
}
if (is_forbidden) continue;
headers[inx] = .{
.name = signed_header,
.value = request.headers.getFirstValue(signed_header).?,
};
inx += 1;
}
var target_iterator = std.mem.splitSequence(u8, request.target, "?");
var signed_request = base.Request{
.path = target_iterator.first(),
.headers = headers[0..inx],
.method = @tagName(request.method),
.content_type = request.headers.getFirstValue("content-type").?,
};
signed_request.query = request.target[signed_request.path.len..]; // TODO: should this be +1? query here would include '?'
signed_request.body = try request_body_reader.readAllAlloc(allocator, std.math.maxInt(usize));
defer allocator.free(signed_request.body);
signed_request = try signRequest(allocator, signed_request, config);
defer freeSignedRequest(allocator, &signed_request, config);
return verifySignedRequest(signed_request, signature);
}
fn verifySignedRequest(signed_request: base.Request, signature: []const u8) !bool {
// We're not doing a lot of error checking here...we are all in control of this code
const auth_header = blk: {
for (signed_request.headers) |header| {
if (std.mem.eql(u8, header.name, "Authorization"))
break :blk header.value;
}
break :blk null;
};
var split_iterator = std.mem.splitSequence(u8, auth_header.?, " ");
const calculated_signature = blk: {
while (split_iterator.next()) |auth_part| {
if (std.ascii.startsWithIgnoreCase(auth_part, "Signature=")) {
break :blk std.mem.trim(u8, auth_part["Signature=".len..], ",");
}
}
break :blk null;
};
return std.mem.eql(u8, signature, calculated_signature.?);
}
fn getSigningKey(allocator: std.mem.Allocator, signing_date: []const u8, config: Config) ![]const u8 {
// TODO: This is designed for lots of caching. We need to work that out
// kSecret = your secret access key
@ -896,3 +1026,50 @@ test "can sign" {
try std.testing.expectEqualStrings("Authorization", signed_req.headers[signed_req.headers.len - 1].name);
try std.testing.expectEqualStrings(expected_auth, signed_req.headers[signed_req.headers.len - 1].value);
}
var test_credential: ?auth.Credentials = null;
test "can verify" {
const allocator = std.testing.allocator;
const access_key = try allocator.dupe(u8, "ACCESS");
const secret_key = try allocator.dupe(u8, "SECRET");
test_credential = auth.Credentials.init(allocator, access_key, secret_key, null);
defer test_credential.?.deinit();
var headers = std.http.Headers.init(allocator);
defer headers.deinit();
try headers.append("Connection", "keep-alive");
try headers.append("Accept-Encoding", "gzip, deflate, zstd");
try headers.append("TE", "gzip, deflate, trailers");
try headers.append("Accept", "application/json");
try headers.append("Host", "127.0.0.1");
try headers.append("User-Agent", "zig-aws 1.0");
try headers.append("Content-Type", "text/plain");
try headers.append("x-amz-storage-class", "STANDARD");
try headers.append("Content-Length", "3");
try headers.append("X-Amz-Date", "20230908T170252Z");
try headers.append("x-amz-content-sha256", "fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9");
try headers.append("Authorization", "AWS4-HMAC-SHA256 Credential=ACCESS/20230908/us-west-2/s3/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class, Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523");
var buf = "bar".*;
var fis = std.io.fixedBufferStream(&buf);
var request = std.http.Server.Request{
.method = std.http.Method.PUT,
.target = "/mysfitszj3t6webstack-hostingbucketa91a61fe-1ep3ezkgwpxr0/i/am/a/teapot/foo?x-id=PutObject",
.version = .@"HTTP/1.1",
.content_length = 3,
.headers = headers,
.parser = std.http.protocol.HeadersParser.initDynamic(std.math.maxInt(usize)),
};
// std.testing.log_level = .debug;
try std.testing.expect(try verify(allocator, request, fis.reader(), struct {
cred: auth.Credentials,
const Self = @This();
fn getCreds(access: []const u8) ?auth.Credentials {
if (std.mem.eql(u8, access, "ACCESS")) return test_credential.?;
return null;
}
}.getCreds));
}

View File

@ -147,6 +147,8 @@ pub fn parseIso8601ToDateTime(data: []const u8) !DateTime {
// Basic format YYYYMMDDThhmmss
if (data.len == "YYYYMMDDThhmmss".len and data[8] == 'T')
return try parseIso8601BasicFormatToDateTime(data);
if (data.len == "YYYYMMDDThhmmssZ".len and data[8] == 'T')
return try parseIso8601BasicFormatToDateTime(data);
var start: usize = 0;
var state = IsoParsingState.Start;
@ -228,7 +230,7 @@ fn endIsoState(current_state: IsoParsingState, date: *DateTime, prev_data: []con
}
return next_state;
}
fn dateTimeToTimestamp(datetime: DateTime) !i64 {
pub fn dateTimeToTimestamp(datetime: DateTime) !i64 {
const epoch = DateTime{
.year = 1970,
.month = 1,
@ -275,9 +277,9 @@ fn secondsBetween(start: DateTime, end: DateTime) DateTimeToTimestampError!i64 {
const leap_years_between = leapYearsBetween(start.year, end.year);
var add_days: u1 = 0;
const years_diff = end.year - start.year;
log.debug("Years from epoch: {d}, Leap years: {d}", .{ years_diff, leap_years_between });
// log.debug("Years from epoch: {d}, Leap years: {d}", .{ years_diff, leap_years_between });
var days_diff: i32 = (years_diff * DAYS_PER_YEAR) + leap_years_between + add_days;
log.debug("Days with leap year, without month: {d}", .{days_diff});
// log.debug("Days with leap year, without month: {d}", .{days_diff});
const seconds_into_year = secondsFromBeginningOfYear(
end.year,
@ -310,15 +312,15 @@ fn secondsFromBeginningOfYear(year: u16, month: u8, day: u8, hour: u8, minute: u
days_diff += days_per_month[current_month - 1]; // months are 1-based vs array is 0-based
current_month += 1;
}
log.debug("Days with month, without day: {d}. Day of month {d}, will add {d} days", .{
days_diff,
day,
day - 1,
});
// log.debug("Days with month, without day: {d}. Day of month {d}, will add {d} days", .{
// days_diff,
// day,
// day - 1,
// });
// We need -1 because we're not actually including the ending day (that's up to hour/minute)
// In other words, days in the month are 1-based, while hours/minutes are zero based
days_diff += day - 1;
log.debug("Total days diff: {d}", .{days_diff});
// log.debug("Total days diff: {d}", .{days_diff});
var seconds_diff: u32 = days_diff * SECONDS_PER_DAY;
// From here out, we want to get everything into seconds
@ -339,7 +341,7 @@ fn leapYearsBetween(start_year_inclusive: u16, end_year_exclusive: u16) u16 {
const start = @min(start_year_inclusive, end_year_exclusive);
const end = @max(start_year_inclusive, end_year_exclusive);
var current = start;
log.debug("Leap years starting from {d}, ending at {d}", .{ start, end });
// log.debug("Leap years starting from {d}, ending at {d}", .{ start, end });
while (current % 4 != 0 and current < end) {
current += 1;
}
@ -349,7 +351,7 @@ fn leapYearsBetween(start_year_inclusive: u16, end_year_exclusive: u16) u16 {
while (current < end) {
if (current % 4 == 0) {
if (current % 100 != 0) {
log.debug("Year {d} is leap year", .{current});
// log.debug("Year {d} is leap year", .{current});
rc += 1;
current += 4;
continue;
@ -357,7 +359,7 @@ fn leapYearsBetween(start_year_inclusive: u16, end_year_exclusive: u16) u16 {
// We're on a century, which is normally not a leap year, unless
// it's divisible by 400
if (current % 400 == 0) {
log.debug("Year {d} is leap year", .{current});
// log.debug("Year {d} is leap year", .{current});
rc += 1;
}
}