account id -> u40, use Access Key embedding

This commit is contained in:
Emil Lerch 2024-03-04 13:17:10 -08:00
parent 9d6527acf4
commit 4f0c608392
Signed by: lobo
GPG Key ID: A7B62D657EF764F8
5 changed files with 92 additions and 51 deletions

View File

@ -1,3 +1,4 @@
const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const encryption = @import("encryption.zig"); const encryption = @import("encryption.zig");
const sqlite = @import("sqlite"); // TODO: If we use this across all services, Account should not have this, and we should have a localdbaccount struct const sqlite = @import("sqlite"); // TODO: If we use this across all services, Account should not have this, and we should have a localdbaccount struct
@ -10,10 +11,10 @@ const Self = @This();
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
root_account_key: *[encryption.key_length]u8, root_account_key: *[encryption.key_length]u8,
pub var root_key_mapping: ?std.StringHashMap([]const u8) = null; pub var root_key_mapping: ?std.AutoHashMap(u40, []const u8) = null;
pub fn accountForId(allocator: std.mem.Allocator, account_id: []const u8) !Self { pub fn accountForId(allocator: std.mem.Allocator, account_id: u40) !Self {
if (std.mem.eql(u8, account_id, "1234")) { if (account_id == 1234) {
var key = try allocator.alloc(u8, encryption.key_length); var key = try allocator.alloc(u8, encryption.key_length);
errdefer allocator.free(key); errdefer allocator.free(key);
try encryption.decodeKey(key[0..encryption.key_length], test_account_key.*); try encryption.decodeKey(key[0..encryption.key_length], test_account_key.*);
@ -37,7 +38,7 @@ pub fn accountForId(allocator: std.mem.Allocator, account_id: []const u8) !Self
} }
// TODO: Check STS // TODO: Check STS
log.err("Got account id '{s}', but could not find this ('1234' is test account). STS GetAccessKeyInfo not implemented", .{account_id}); log.err("Got account id '{d}', but could not find this ('1234' is test account). STS GetAccessKeyInfo not implemented", .{account_id});
return error.NotImplemented; return error.NotImplemented;
} }
@ -59,14 +60,13 @@ pub fn testDbDeinit() void {
} }
/// Gets the database for this account. If under test, a memory database is used /// Gets the database for this account. If under test, a memory database is used
/// instead. Will initialize the database with appropriate metadata tables /// instead. Will initialize the database with appropriate metadata tables
pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db { pub fn dbForAccount(allocator: std.mem.Allocator, account_id: u40) !*sqlite.Db {
const builtin = @import("builtin");
if (builtin.is_test and test_retain_db) if (builtin.is_test and test_retain_db)
if (test_db) |db| return db; if (test_db) |db| return db;
// TODO: Need to move this function somewhere central // TODO: Need to move this function somewhere central
// TODO: Need configuration for what directory to use // TODO: Need configuration for what directory to use
// TODO: Should this be a pool, and if so, how would we know when to close? // TODO: Should this be a pool, and if so, how would we know when to close?
const file_without_path = try std.fmt.allocPrint(allocator, "ddb-{s}.sqlite3", .{account_id}); const file_without_path = try std.fmt.allocPrint(allocator, "ddb-{d}.sqlite3", .{account_id});
defer allocator.free(file_without_path); defer allocator.free(file_without_path);
const db_file_name = try std.fs.path.joinZ(allocator, &[_][]const u8{ data_dir, file_without_path }); const db_file_name = try std.fs.path.joinZ(allocator, &[_][]const u8{ data_dir, file_without_path });
defer allocator.free(db_file_name); defer allocator.free(db_file_name);

View File

@ -5,7 +5,7 @@ event_data: []const u8,
headers: std.http.Headers, headers: std.http.Headers,
status: std.http.Status, status: std.http.Status,
reason: ?[]const u8, reason: ?[]const u8,
account_id: []const u8, account_id: u40,
output_format: OutputFormat, output_format: OutputFormat,
pub const OutputFormat = enum { pub const OutputFormat = enum {

View File

@ -33,6 +33,13 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 {
var parsed = try std.json.parseFromSlice(std.json.Value, allocator, request.event_data, .{}); var parsed = try std.json.parseFromSlice(std.json.Value, allocator, request.event_data, .{});
defer parsed.deinit(); defer parsed.deinit();
const request_params = try parseRequest(request, parsed, writer); const request_params = try parseRequest(request, parsed, writer);
defer {
for (request_params.table_info.attribute_definitions) |d| {
allocator.free(d.*.name);
allocator.destroy(d);
}
allocator.free(request_params.table_info.attribute_definitions);
}
// Parsing does most validation for us, but we also need to make sure that // Parsing does most validation for us, but we also need to make sure that
// the attributes specified in the key schema actually exist // the attributes specified in the key schema actually exist
var found_keys: u2 = if (request_params.table_info.range_key_attribute_name == null) 0b01 else 0b00; var found_keys: u2 = if (request_params.table_info.range_key_attribute_name == null) 0b01 else 0b00;
@ -54,13 +61,6 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 {
writer, writer,
"Attribute names in KeySchema must also exist in AttributeDefinitions", "Attribute names in KeySchema must also exist in AttributeDefinitions",
); );
defer {
for (request_params.table_info.attribute_definitions) |d| {
allocator.free(d.*.name);
allocator.destroy(d);
}
allocator.free(request_params.table_info.attribute_definitions);
}
var db = try Account.dbForAccount(allocator, account_id); var db = try Account.dbForAccount(allocator, account_id);
defer allocator.destroy(db); defer allocator.destroy(db);
defer db.deinit(); defer db.deinit();
@ -144,7 +144,7 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 {
var al = std.ArrayList(u8).init(allocator); var al = std.ArrayList(u8).init(allocator);
var response_writer = al.writer(); var response_writer = al.writer();
try response_writer.print("table created for account {s}\n", .{account_id}); try response_writer.print("table created for account {d}\n", .{account_id});
return al.toOwnedSlice(); return al.toOwnedSlice();
} }

View File

@ -482,7 +482,7 @@ pub const Table = struct {
/// are stored in here, realistically, this will be the first function called /// are stored in here, realistically, this will be the first function called
/// every time anything interacts with the database, so this function opens /// every time anything interacts with the database, so this function opens
/// the database for you /// the database for you
pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !AccountTables { pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: u40) !AccountTables {
// TODO: This function should take a list of table names, which can then be used // TODO: This function should take a list of table names, which can then be used
// to filter the query below rather than just grabbing everything // to filter the query below rather than just grabbing everything
@ -676,7 +676,7 @@ fn insertIntoDm(
}); });
} }
fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db { fn testCreateTable(allocator: std.mem.Allocator, account_id: u40) !*sqlite.Db {
var db = try Account.dbForAccount(allocator, account_id); var db = try Account.dbForAccount(allocator, account_id);
const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed
defer account.deinit(); defer account.deinit();
@ -707,7 +707,7 @@ fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !*sqlit
} }
test "can create a table" { test "can create a table" {
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
const account_id = "1234"; const account_id = 1234;
var db = try testCreateTable(allocator, account_id); var db = try testCreateTable(allocator, account_id);
defer allocator.destroy(db); defer allocator.destroy(db);
defer db.deinit(); defer db.deinit();
@ -715,7 +715,7 @@ test "can create a table" {
test "can list tables in an account" { test "can list tables in an account" {
Account.test_retain_db = true; Account.test_retain_db = true;
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
const account_id = "1234"; const account_id = 1234;
var db = try testCreateTable(allocator, account_id); var db = try testCreateTable(allocator, account_id);
defer allocator.destroy(db); defer allocator.destroy(db);
defer Account.testDbDeinit(); defer Account.testDbDeinit();
@ -729,7 +729,7 @@ test "can list tables in an account" {
test "can put an item in a table in an account" { test "can put an item in a table in an account" {
Account.test_retain_db = true; Account.test_retain_db = true;
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
const account_id = "1234"; const account_id = 1234;
var db = try testCreateTable(allocator, account_id); var db = try testCreateTable(allocator, account_id);
defer allocator.destroy(db); defer allocator.destroy(db);
defer Account.testDbDeinit(); defer Account.testDbDeinit();

View File

@ -1,3 +1,4 @@
const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const universal_lambda = @import("universal_lambda_handler"); const universal_lambda = @import("universal_lambda_handler");
const universal_lambda_interface = @import("universal_lambda_interface"); const universal_lambda_interface = @import("universal_lambda_interface");
@ -24,24 +25,6 @@ pub fn main() !u8 {
} }
pub fn handler(allocator: std.mem.Allocator, event_data: []const u8, context: universal_lambda_interface.Context) ![]const u8 { pub fn handler(allocator: std.mem.Allocator, event_data: []const u8, context: universal_lambda_interface.Context) ![]const u8 {
const builtin = @import("builtin");
var rss: if (builtin.os.tag == .linux) std.os.rusage else usize = undefined;
if (builtin.os.tag == .linux and builtin.mode == .Debug)
rss = std.os.getrusage(std.os.rusage.SELF);
defer if (builtin.os.tag == .linux and builtin.mode == .Debug) { // and debug mode) {
const rusage = std.os.getrusage(std.os.rusage.SELF);
log.debug(
"Request complete, max RSS of process: {d}M. Incremental: {d}K, User: {d}μs, System: {d}μs",
.{
@divTrunc(rusage.maxrss, 1024),
rusage.maxrss - rss.maxrss,
(rusage.utime.tv_sec - rss.utime.tv_sec) * std.time.us_per_s +
rusage.utime.tv_usec - rss.utime.tv_usec,
(rusage.stime.tv_sec - rss.stime.tv_sec) * std.time.us_per_s +
rusage.stime.tv_usec - rss.stime.tv_usec,
},
);
};
const access_key = try allocator.dupe(u8, "ACCESS"); const access_key = try allocator.dupe(u8, "ACCESS");
const secret_key = try allocator.dupe(u8, "SECRET"); const secret_key = try allocator.dupe(u8, "SECRET");
test_credential = signing.Credentials.init(allocator, access_key, secret_key, null); test_credential = signing.Credentials.init(allocator, access_key, secret_key, null);
@ -146,7 +129,7 @@ fn authenticateUser(allocator: std.mem.Allocator, context: universal_lambda_inte
var test_credential: signing.Credentials = undefined; var test_credential: signing.Credentials = undefined;
var root_creds: std.StringHashMap(signing.Credentials) = undefined; var root_creds: std.StringHashMap(signing.Credentials) = undefined;
var root_account_mapping: std.StringHashMap([]const u8) = undefined; // var root_account_mapping: std.StringHashMap([]const u8) = undefined;
var creds_buf: [8192]u8 = undefined; var creds_buf: [8192]u8 = undefined;
fn getCreds(access: []const u8) ?signing.Credentials { fn getCreds(access: []const u8) ?signing.Credentials {
// We have 3 levels of access here // We have 3 levels of access here
@ -163,8 +146,8 @@ fn getCreds(access: []const u8) ?signing.Credentials {
fn fillRootCreds(allocator: std.mem.Allocator) !void { fn fillRootCreds(allocator: std.mem.Allocator) !void {
root_creds = std.StringHashMap(signing.Credentials).init(allocator); root_creds = std.StringHashMap(signing.Credentials).init(allocator);
root_account_mapping = std.StringHashMap([]const u8).init(allocator); // root_account_mapping = std.StringHashMap([]const u8).init(allocator);
Account.root_key_mapping = std.StringHashMap([]const u8).init(allocator); Account.root_key_mapping = std.AutoHashMap(u40, []const u8).init(allocator);
var file = std.fs.cwd().openFile("access_keys.csv", .{}) catch |e| { var file = std.fs.cwd().openFile("access_keys.csv", .{}) catch |e| {
log.err("Could not open access_keys.csv to access root creds: {}", .{e}); log.err("Could not open access_keys.csv to access root creds: {}", .{e});
return e; return e;
@ -219,8 +202,9 @@ fn fillRootCreds(allocator: std.mem.Allocator) !void {
.session_token = null, .session_token = null,
.allocator = NullAllocator.init(), .allocator = NullAllocator.init(),
}); });
const global_account_id = try allocator.dupe(u8, account_id); const global_account_id = try std.fmt.parseInt(u40, account_id, 10);
try root_account_mapping.put(global_access_key, global_account_id); // unnecessary. Account ids are embedded in access keys!
// try root_account_mapping.put(global_access_key, global_account_id);
try Account.root_key_mapping.?.put(global_account_id, try allocator.dupe(u8, existing_key)); try Account.root_key_mapping.?.put(global_account_id, try allocator.dupe(u8, existing_key));
// TODO: key rotation will need another hash map, can be triggered on val_num == 5 // TODO: key rotation will need another hash map, can be triggered on val_num == 5
@ -270,19 +254,21 @@ const NullAllocator = struct {
} }
}; };
fn accountForAccessKey(allocator: std.mem.Allocator, access_key: []const u8) ![]const u8 { fn accountForAccessKey(allocator: std.mem.Allocator, access_key: []const u8) !u40 {
_ = allocator; _ = allocator;
log.debug("Finding account for access key: '{s}'", .{access_key}); log.debug("Finding account for access key: '{s}'", .{access_key});
if (access_key.len != 20) return error.InvalidAccessKey;
return try accountIdForAccessKey(@as(*[20]u8, @ptrCast(@constCast(access_key))).*);
// Since this happens after authentication, we can assume our root creds store // Since this happens after authentication, we can assume our root creds store
// is populated // is populated
if (root_account_mapping.get(access_key)) |account| return account; // if (root_account_mapping.get(access_key)) |account| return account;
log.err("Creds not found in store. STS GetAccessKeyInfo call is not yet implemented", .{}); // log.err("Creds not found in store. STS GetAccessKeyInfo call is not yet implemented", .{});
return error.NotImplemented; // return error.NotImplemented;
} }
/// Function assumes an authenticated request, so signing.verify must be called /// Function assumes an authenticated request, so signing.verify must be called
/// and returned true before calling this function. If authentication header /// and returned true before calling this function. If authentication header
/// is not found, environment variable will be used /// is not found, environment variable will be used
fn accountId(allocator: std.mem.Allocator, headers: std.http.Headers) ![]const u8 { fn accountId(allocator: std.mem.Allocator, headers: std.http.Headers) !u40 {
const auth_header = headers.getFirstValue("Authorization"); const auth_header = headers.getFirstValue("Authorization");
if (auth_header) |h| { if (auth_header) |h| {
// AWS4-HMAC-SHA256 Credential=ACCESS/20230908/us-west-2/s3/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class, Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523 // AWS4-HMAC-SHA256 Credential=ACCESS/20230908/us-west-2/s3/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class, Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523
@ -326,8 +312,8 @@ fn iamCredentials(allocator: std.mem.Allocator) ![]const u8 {
iam_credential = signing.Credentials.init(allocator, try iamAccessKey(allocator), try iamSecretKey(allocator), null); iam_credential = signing.Credentials.init(allocator, try iamAccessKey(allocator), try iamSecretKey(allocator), null);
return iam_credential.?; return iam_credential.?;
} }
fn iamAccountId(allocator: std.mem.Allocator) ![]const u8 { fn iamAccountId(allocator: std.mem.Allocator) !u40 {
return try getVariable(allocator, &iam_account_id, "IAM_ACCOUNT_ID"); return std.fmt.parseInt(u40, try getVariable(allocator, &iam_account_id, "IAM_ACCOUNT_ID"), 10);
} }
fn iamAccessKey(allocator: std.mem.Allocator) ![]const u8 { fn iamAccessKey(allocator: std.mem.Allocator) ![]const u8 {
return try getVariable(allocator, &iam_access_key, "IAM_ACCESS_KEY"); return try getVariable(allocator, &iam_access_key, "IAM_ACCESS_KEY");
@ -346,3 +332,58 @@ test {
std.testing.refAllDecls(@import("batchwriteitem.zig")); std.testing.refAllDecls(@import("batchwriteitem.zig"));
std.testing.refAllDecls(@import("batchgetitem.zig")); std.testing.refAllDecls(@import("batchgetitem.zig"));
} }
test "can get account id from access key" {
// ELAKM5YGIGQQAD2B54IZ, Account 888534479904
// Also, https://medium.com/@TalBeerySec/a-short-note-on-aws-key-id-f88cc4317489
// aws_access_key_id: ASIAY34FZKBOKMUTVV7A yields the expected account id "609629065308"
try std.testing.expectEqual(@as(u40, 609629065308), try accountIdForAccessKey(@as(*[20]u8, @ptrCast(@constCast("ASIAY34FZKBOKMUTVV7A"))).*));
try std.testing.expectEqual(@as(u40, 888534479904), try accountIdForAccessKey(@as(*[20]u8, @ptrCast(@constCast("ELAKM5YGIGQQAD2B54IZ"))).*));
}
fn accountIdForAccessKey(access_key: [20]u8) !u40 {
const ak_integer_part = access_key[4..];
const ak_integer = try base32Decode(u80, @as(*[16]u8, @ptrCast(@constCast(ak_integer_part.ptr))).*);
const account_id = ak_integer >> 39;
return @as(u40, @truncate(account_id));
// Do we want an array like this? Probably so
// import base64
// import binascii
//
// def AWSAccount_from_AWSKeyID(AWSKeyID):
//
// trimmed_AWSKeyID = AWSKeyID[4:] #remove KeyID prefix
// x = base64.b32decode(trimmed_AWSKeyID) #base32 decode
// y = x[0:6]
//
// z = int.from_bytes(y, byteorder='big', signed=False)
// mask = int.from_bytes(binascii.unhexlify(b'7fffffffff80'), byteorder='big', signed=False)
//
// e = (z & mask)>>7
// return (e)
}
fn base32Decode(comptime T: type, data: [@typeInfo(T).Int.bits / 5]u8) !T {
// const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
const ti = @typeInfo(T);
if (ti != .Int or ti.Int.signedness != .unsigned)
@compileError("decode only works with unsigned integers");
if (ti.Int.bits % 5 != 0)
@compileError("unsigned integer bit length must be a multiple of 5 to use this function");
const Shift_type = @Type(.{ .Int = .{
.signedness = .unsigned,
.bits = @ceil(@log2(@as(f128, @floatFromInt(ti.Int.bits)))),
} });
var rc: T = 0;
for (data, 0..) |b, i| {
var curr: T = 0;
if (b >= 'A' and b <= 'Z') {
curr = b - 'A';
} else if (b >= '2' and b <= '7') {
curr = b - '2' + 26;
} else return error.InvalidCharacter;
curr <<= @as(Shift_type, @intCast((data.len - 1 - i) * 5));
rc |= curr;
}
return rc;
}