account id -> u40, use Access Key embedding
This commit is contained in:
		
							parent
							
								
									9d6527acf4
								
							
						
					
					
						commit
						4f0c608392
					
				
					 5 changed files with 92 additions and 51 deletions
				
			
		|  | @ -1,3 +1,4 @@ | |||
| const builtin = @import("builtin"); | ||||
| const std = @import("std"); | ||||
| const encryption = @import("encryption.zig"); | ||||
| const sqlite = @import("sqlite"); // TODO: If we use this across all services, Account should not have this, and we should have a localdbaccount struct | ||||
|  | @ -10,10 +11,10 @@ const Self = @This(); | |||
| allocator: std.mem.Allocator, | ||||
| root_account_key: *[encryption.key_length]u8, | ||||
| 
 | ||||
| pub var root_key_mapping: ?std.StringHashMap([]const u8) = null; | ||||
| pub var root_key_mapping: ?std.AutoHashMap(u40, []const u8) = null; | ||||
| 
 | ||||
| pub fn accountForId(allocator: std.mem.Allocator, account_id: []const u8) !Self { | ||||
|     if (std.mem.eql(u8, account_id, "1234")) { | ||||
| pub fn accountForId(allocator: std.mem.Allocator, account_id: u40) !Self { | ||||
|     if (account_id == 1234) { | ||||
|         var key = try allocator.alloc(u8, encryption.key_length); | ||||
|         errdefer allocator.free(key); | ||||
|         try encryption.decodeKey(key[0..encryption.key_length], test_account_key.*); | ||||
|  | @ -37,7 +38,7 @@ pub fn accountForId(allocator: std.mem.Allocator, account_id: []const u8) !Self | |||
|     } | ||||
| 
 | ||||
|     // TODO: Check STS | ||||
|     log.err("Got account id '{s}', but could not find this ('1234' is test account). STS GetAccessKeyInfo not implemented", .{account_id}); | ||||
|     log.err("Got account id '{d}', but could not find this ('1234' is test account). STS GetAccessKeyInfo not implemented", .{account_id}); | ||||
|     return error.NotImplemented; | ||||
| } | ||||
| 
 | ||||
|  | @ -59,14 +60,13 @@ pub fn testDbDeinit() void { | |||
| } | ||||
| /// Gets the database for this account. If under test, a memory database is used | ||||
| /// instead. Will initialize the database with appropriate metadata tables | ||||
| pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db { | ||||
|     const builtin = @import("builtin"); | ||||
| pub fn dbForAccount(allocator: std.mem.Allocator, account_id: u40) !*sqlite.Db { | ||||
|     if (builtin.is_test and test_retain_db) | ||||
|         if (test_db) |db| return db; | ||||
|     // TODO: Need to move this function somewhere central | ||||
|     // TODO: Need configuration for what directory to use | ||||
|     // TODO: Should this be a pool, and if so, how would we know when to close? | ||||
|     const file_without_path = try std.fmt.allocPrint(allocator, "ddb-{s}.sqlite3", .{account_id}); | ||||
|     const file_without_path = try std.fmt.allocPrint(allocator, "ddb-{d}.sqlite3", .{account_id}); | ||||
|     defer allocator.free(file_without_path); | ||||
|     const db_file_name = try std.fs.path.joinZ(allocator, &[_][]const u8{ data_dir, file_without_path }); | ||||
|     defer allocator.free(db_file_name); | ||||
|  |  | |||
|  | @ -5,7 +5,7 @@ event_data: []const u8, | |||
| headers: std.http.Headers, | ||||
| status: std.http.Status, | ||||
| reason: ?[]const u8, | ||||
| account_id: []const u8, | ||||
| account_id: u40, | ||||
| output_format: OutputFormat, | ||||
| 
 | ||||
| pub const OutputFormat = enum { | ||||
|  |  | |||
|  | @ -33,6 +33,13 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 { | |||
|     var parsed = try std.json.parseFromSlice(std.json.Value, allocator, request.event_data, .{}); | ||||
|     defer parsed.deinit(); | ||||
|     const request_params = try parseRequest(request, parsed, writer); | ||||
|     defer { | ||||
|         for (request_params.table_info.attribute_definitions) |d| { | ||||
|             allocator.free(d.*.name); | ||||
|             allocator.destroy(d); | ||||
|         } | ||||
|         allocator.free(request_params.table_info.attribute_definitions); | ||||
|     } | ||||
|     // Parsing does most validation for us, but we also need to make sure that | ||||
|     // the attributes specified in the key schema actually exist | ||||
|     var found_keys: u2 = if (request_params.table_info.range_key_attribute_name == null) 0b01 else 0b00; | ||||
|  | @ -54,13 +61,6 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 { | |||
|             writer, | ||||
|             "Attribute names in KeySchema must also exist in AttributeDefinitions", | ||||
|         ); | ||||
|     defer { | ||||
|         for (request_params.table_info.attribute_definitions) |d| { | ||||
|             allocator.free(d.*.name); | ||||
|             allocator.destroy(d); | ||||
|         } | ||||
|         allocator.free(request_params.table_info.attribute_definitions); | ||||
|     } | ||||
|     var db = try Account.dbForAccount(allocator, account_id); | ||||
|     defer allocator.destroy(db); | ||||
|     defer db.deinit(); | ||||
|  | @ -144,7 +144,7 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 { | |||
| 
 | ||||
|     var al = std.ArrayList(u8).init(allocator); | ||||
|     var response_writer = al.writer(); | ||||
|     try response_writer.print("table created for account {s}\n", .{account_id}); | ||||
|     try response_writer.print("table created for account {d}\n", .{account_id}); | ||||
|     return al.toOwnedSlice(); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										10
									
								
								src/ddb.zig
									
										
									
									
									
								
							
							
						
						
									
										10
									
								
								src/ddb.zig
									
										
									
									
									
								
							|  | @ -482,7 +482,7 @@ pub const Table = struct { | |||
| /// are stored in here, realistically, this will be the first function called | ||||
| /// every time anything interacts with the database, so this function opens | ||||
| /// the database for you | ||||
| pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !AccountTables { | ||||
| pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: u40) !AccountTables { | ||||
| 
 | ||||
|     // TODO: This function should take a list of table names, which can then be used | ||||
|     // to filter the query below rather than just grabbing everything | ||||
|  | @ -676,7 +676,7 @@ fn insertIntoDm( | |||
|     }); | ||||
| } | ||||
| 
 | ||||
| fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db { | ||||
| fn testCreateTable(allocator: std.mem.Allocator, account_id: u40) !*sqlite.Db { | ||||
|     var db = try Account.dbForAccount(allocator, account_id); | ||||
|     const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed | ||||
|     defer account.deinit(); | ||||
|  | @ -707,7 +707,7 @@ fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !*sqlit | |||
| } | ||||
| test "can create a table" { | ||||
|     const allocator = std.testing.allocator; | ||||
|     const account_id = "1234"; | ||||
|     const account_id = 1234; | ||||
|     var db = try testCreateTable(allocator, account_id); | ||||
|     defer allocator.destroy(db); | ||||
|     defer db.deinit(); | ||||
|  | @ -715,7 +715,7 @@ test "can create a table" { | |||
| test "can list tables in an account" { | ||||
|     Account.test_retain_db = true; | ||||
|     const allocator = std.testing.allocator; | ||||
|     const account_id = "1234"; | ||||
|     const account_id = 1234; | ||||
|     var db = try testCreateTable(allocator, account_id); | ||||
|     defer allocator.destroy(db); | ||||
|     defer Account.testDbDeinit(); | ||||
|  | @ -729,7 +729,7 @@ test "can list tables in an account" { | |||
| test "can put an item in a table in an account" { | ||||
|     Account.test_retain_db = true; | ||||
|     const allocator = std.testing.allocator; | ||||
|     const account_id = "1234"; | ||||
|     const account_id = 1234; | ||||
|     var db = try testCreateTable(allocator, account_id); | ||||
|     defer allocator.destroy(db); | ||||
|     defer Account.testDbDeinit(); | ||||
|  |  | |||
							
								
								
									
										101
									
								
								src/main.zig
									
										
									
									
									
								
							
							
						
						
									
										101
									
								
								src/main.zig
									
										
									
									
									
								
							|  | @ -1,3 +1,4 @@ | |||
| const builtin = @import("builtin"); | ||||
| const std = @import("std"); | ||||
| const universal_lambda = @import("universal_lambda_handler"); | ||||
| const universal_lambda_interface = @import("universal_lambda_interface"); | ||||
|  | @ -24,24 +25,6 @@ pub fn main() !u8 { | |||
| } | ||||
| 
 | ||||
| pub fn handler(allocator: std.mem.Allocator, event_data: []const u8, context: universal_lambda_interface.Context) ![]const u8 { | ||||
|     const builtin = @import("builtin"); | ||||
|     var rss: if (builtin.os.tag == .linux) std.os.rusage else usize = undefined; | ||||
|     if (builtin.os.tag == .linux and builtin.mode == .Debug) | ||||
|         rss = std.os.getrusage(std.os.rusage.SELF); | ||||
|     defer if (builtin.os.tag == .linux and builtin.mode == .Debug) { // and  debug mode) { | ||||
|         const rusage = std.os.getrusage(std.os.rusage.SELF); | ||||
|         log.debug( | ||||
|             "Request complete, max RSS of process: {d}M. Incremental: {d}K, User: {d}μs, System: {d}μs", | ||||
|             .{ | ||||
|                 @divTrunc(rusage.maxrss, 1024), | ||||
|                 rusage.maxrss - rss.maxrss, | ||||
|                 (rusage.utime.tv_sec - rss.utime.tv_sec) * std.time.us_per_s + | ||||
|                     rusage.utime.tv_usec - rss.utime.tv_usec, | ||||
|                 (rusage.stime.tv_sec - rss.stime.tv_sec) * std.time.us_per_s + | ||||
|                     rusage.stime.tv_usec - rss.stime.tv_usec, | ||||
|             }, | ||||
|         ); | ||||
|     }; | ||||
|     const access_key = try allocator.dupe(u8, "ACCESS"); | ||||
|     const secret_key = try allocator.dupe(u8, "SECRET"); | ||||
|     test_credential = signing.Credentials.init(allocator, access_key, secret_key, null); | ||||
|  | @ -146,7 +129,7 @@ fn authenticateUser(allocator: std.mem.Allocator, context: universal_lambda_inte | |||
| 
 | ||||
| var test_credential: signing.Credentials = undefined; | ||||
| var root_creds: std.StringHashMap(signing.Credentials) = undefined; | ||||
| var root_account_mapping: std.StringHashMap([]const u8) = undefined; | ||||
| // var root_account_mapping: std.StringHashMap([]const u8) = undefined; | ||||
| var creds_buf: [8192]u8 = undefined; | ||||
| fn getCreds(access: []const u8) ?signing.Credentials { | ||||
|     // We have 3 levels of access here | ||||
|  | @ -163,8 +146,8 @@ fn getCreds(access: []const u8) ?signing.Credentials { | |||
| 
 | ||||
| fn fillRootCreds(allocator: std.mem.Allocator) !void { | ||||
|     root_creds = std.StringHashMap(signing.Credentials).init(allocator); | ||||
|     root_account_mapping = std.StringHashMap([]const u8).init(allocator); | ||||
|     Account.root_key_mapping = std.StringHashMap([]const u8).init(allocator); | ||||
|     // root_account_mapping = std.StringHashMap([]const u8).init(allocator); | ||||
|     Account.root_key_mapping = std.AutoHashMap(u40, []const u8).init(allocator); | ||||
|     var file = std.fs.cwd().openFile("access_keys.csv", .{}) catch |e| { | ||||
|         log.err("Could not open access_keys.csv to access root creds: {}", .{e}); | ||||
|         return e; | ||||
|  | @ -219,8 +202,9 @@ fn fillRootCreds(allocator: std.mem.Allocator) !void { | |||
|             .session_token = null, | ||||
|             .allocator = NullAllocator.init(), | ||||
|         }); | ||||
|         const global_account_id = try allocator.dupe(u8, account_id); | ||||
|         try root_account_mapping.put(global_access_key, global_account_id); | ||||
|         const global_account_id = try std.fmt.parseInt(u40, account_id, 10); | ||||
|         // unnecessary. Account ids are embedded in access keys! | ||||
|         // try root_account_mapping.put(global_access_key, global_account_id); | ||||
|         try Account.root_key_mapping.?.put(global_account_id, try allocator.dupe(u8, existing_key)); | ||||
|         // TODO: key rotation will need another hash map, can be triggered on val_num == 5 | ||||
| 
 | ||||
|  | @ -270,19 +254,21 @@ const NullAllocator = struct { | |||
|     } | ||||
| }; | ||||
| 
 | ||||
| fn accountForAccessKey(allocator: std.mem.Allocator, access_key: []const u8) ![]const u8 { | ||||
| fn accountForAccessKey(allocator: std.mem.Allocator, access_key: []const u8) !u40 { | ||||
|     _ = allocator; | ||||
|     log.debug("Finding account for access key: '{s}'", .{access_key}); | ||||
|     if (access_key.len != 20) return error.InvalidAccessKey; | ||||
|     return try accountIdForAccessKey(@as(*[20]u8, @ptrCast(@constCast(access_key))).*); | ||||
|     // Since this happens after authentication, we can assume our root creds store | ||||
|     // is populated | ||||
|     if (root_account_mapping.get(access_key)) |account| return account; | ||||
|     log.err("Creds not found in store. STS GetAccessKeyInfo call is not yet implemented", .{}); | ||||
|     return error.NotImplemented; | ||||
|     // if (root_account_mapping.get(access_key)) |account| return account; | ||||
|     // log.err("Creds not found in store. STS GetAccessKeyInfo call is not yet implemented", .{}); | ||||
|     // return error.NotImplemented; | ||||
| } | ||||
| /// Function assumes an authenticated request, so signing.verify must be called | ||||
| /// and returned true before calling this function. If authentication header | ||||
| /// is not found, environment variable will be used | ||||
| fn accountId(allocator: std.mem.Allocator, headers: std.http.Headers) ![]const u8 { | ||||
| fn accountId(allocator: std.mem.Allocator, headers: std.http.Headers) !u40 { | ||||
|     const auth_header = headers.getFirstValue("Authorization"); | ||||
|     if (auth_header) |h| { | ||||
|         // AWS4-HMAC-SHA256 Credential=ACCESS/20230908/us-west-2/s3/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-storage-class, Signature=fcc43ce73a34c9bd1ddf17e8a435f46a859812822f944f9eeb2aabcd64b03523 | ||||
|  | @ -326,8 +312,8 @@ fn iamCredentials(allocator: std.mem.Allocator) ![]const u8 { | |||
|     iam_credential = signing.Credentials.init(allocator, try iamAccessKey(allocator), try iamSecretKey(allocator), null); | ||||
|     return iam_credential.?; | ||||
| } | ||||
| fn iamAccountId(allocator: std.mem.Allocator) ![]const u8 { | ||||
|     return try getVariable(allocator, &iam_account_id, "IAM_ACCOUNT_ID"); | ||||
| fn iamAccountId(allocator: std.mem.Allocator) !u40 { | ||||
|     return std.fmt.parseInt(u40, try getVariable(allocator, &iam_account_id, "IAM_ACCOUNT_ID"), 10); | ||||
| } | ||||
| fn iamAccessKey(allocator: std.mem.Allocator) ![]const u8 { | ||||
|     return try getVariable(allocator, &iam_access_key, "IAM_ACCESS_KEY"); | ||||
|  | @ -346,3 +332,58 @@ test { | |||
|     std.testing.refAllDecls(@import("batchwriteitem.zig")); | ||||
|     std.testing.refAllDecls(@import("batchgetitem.zig")); | ||||
| } | ||||
| 
 | ||||
| test "can get account id from access key" { | ||||
|     // ELAKM5YGIGQQAD2B54IZ, Account 888534479904 | ||||
|     // Also, https://medium.com/@TalBeerySec/a-short-note-on-aws-key-id-f88cc4317489 | ||||
|     // aws_access_key_id: ASIAY34FZKBOKMUTVV7A yields the expected account id "609629065308" | ||||
|     try std.testing.expectEqual(@as(u40, 609629065308), try accountIdForAccessKey(@as(*[20]u8, @ptrCast(@constCast("ASIAY34FZKBOKMUTVV7A"))).*)); | ||||
|     try std.testing.expectEqual(@as(u40, 888534479904), try accountIdForAccessKey(@as(*[20]u8, @ptrCast(@constCast("ELAKM5YGIGQQAD2B54IZ"))).*)); | ||||
| } | ||||
| 
 | ||||
| fn accountIdForAccessKey(access_key: [20]u8) !u40 { | ||||
|     const ak_integer_part = access_key[4..]; | ||||
|     const ak_integer = try base32Decode(u80, @as(*[16]u8, @ptrCast(@constCast(ak_integer_part.ptr))).*); | ||||
|     const account_id = ak_integer >> 39; | ||||
|     return @as(u40, @truncate(account_id)); | ||||
|     // Do we want an array like this? Probably so | ||||
|     // import base64 | ||||
|     // import binascii | ||||
|     // | ||||
|     // def AWSAccount_from_AWSKeyID(AWSKeyID): | ||||
|     // | ||||
|     //     trimmed_AWSKeyID = AWSKeyID[4:] #remove KeyID prefix | ||||
|     //     x = base64.b32decode(trimmed_AWSKeyID) #base32 decode | ||||
|     //     y = x[0:6] | ||||
|     // | ||||
|     //     z = int.from_bytes(y, byteorder='big', signed=False) | ||||
|     //     mask = int.from_bytes(binascii.unhexlify(b'7fffffffff80'), byteorder='big', signed=False) | ||||
|     // | ||||
|     //     e = (z & mask)>>7 | ||||
|     //     return (e) | ||||
| } | ||||
| 
 | ||||
| fn base32Decode(comptime T: type, data: [@typeInfo(T).Int.bits / 5]u8) !T { | ||||
|     // const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; | ||||
|     const ti = @typeInfo(T); | ||||
|     if (ti != .Int or ti.Int.signedness != .unsigned) | ||||
|         @compileError("decode only works with unsigned integers"); | ||||
|     if (ti.Int.bits % 5 != 0) | ||||
|         @compileError("unsigned integer bit length must be a multiple of 5 to use this function"); | ||||
|     const Shift_type = @Type(.{ .Int = .{ | ||||
|         .signedness = .unsigned, | ||||
|         .bits = @ceil(@log2(@as(f128, @floatFromInt(ti.Int.bits)))), | ||||
|     } }); | ||||
|     var rc: T = 0; | ||||
|     for (data, 0..) |b, i| { | ||||
|         var curr: T = 0; | ||||
|         if (b >= 'A' and b <= 'Z') { | ||||
|             curr = b - 'A'; | ||||
|         } else if (b >= '2' and b <= '7') { | ||||
|             curr = b - '2' + 26; | ||||
|         } else return error.InvalidCharacter; | ||||
|         curr <<= @as(Shift_type, @intCast((data.len - 1 - i) * 5)); | ||||
|         rc |= curr; | ||||
|     } | ||||
|     return rc; | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue