From 685bae479b51fdc86da9f9ae562709f89aaf3455 Mon Sep 17 00:00:00 2001 From: Emil Lerch Date: Thu, 8 Feb 2024 15:25:41 -0800 Subject: [PATCH] batchwriteitem is alive --- src/Account.zig | 14 +- src/batchwriteitem.zig | 413 ++++++++++++++++++++++++++++++++++++++++- src/createtable.zig | 4 +- src/ddb.zig | 217 ++++++++++++++++++---- src/encryption.zig | 55 +++++- 5 files changed, 651 insertions(+), 52 deletions(-) diff --git a/src/Account.zig b/src/Account.zig index 966b7bd..68afeeb 100644 --- a/src/Account.zig +++ b/src/Account.zig @@ -30,11 +30,18 @@ pub fn deinit(self: Self) void { pub var data_dir: []const u8 = ""; pub var test_retain_db: bool = false; -var test_db: ?sqlite.Db = null; +var test_db: ?*sqlite.Db = null; +pub fn testDbDeinit() void { + test_retain_db = false; + if (test_db) |db| { + db.deinit(); + test_db = null; + } +} /// Gets the database for this account. If under test, a memory database is used /// instead. Will initialize the database with appropriate metadata tables -pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !sqlite.Db { +pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db { const builtin = @import("builtin"); if (builtin.is_test and test_retain_db) if (test_db) |db| return db; @@ -47,7 +54,8 @@ pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !sqlit defer allocator.free(db_file_name); const mode = if (builtin.is_test) sqlite.Db.Mode.Memory else sqlite.Db.Mode{ .File = db_file_name }; const new = mode == .Memory or (std.fs.cwd().statFile(file_without_path) catch null == null); - var db = try sqlite.Db.init(.{ + var db = try allocator.create(sqlite.Db); + db.* = try sqlite.Db.init(.{ .mode = mode, .open_flags = .{ .write = true, diff --git a/src/batchwriteitem.zig b/src/batchwriteitem.zig index f407e3f..05aead9 100644 --- a/src/batchwriteitem.zig +++ b/src/batchwriteitem.zig @@ -21,6 +21,63 @@ const AttributeValue = union(ddb.AttributeTypeName) { binary_set: [][]const u8, const Self = @This(); + pub fn jsonParse( + allocator: std.mem.Allocator, + source: *std.json.Scanner, + options: std.json.ParseOptions, + ) !Self { + if (.object_begin != try source.next()) return error.UnexpectedToken; + const token = try source.nextAlloc(allocator, options.allocate.?); + if (token != .string) return error.UnexpectedToken; + var rc: Self = undefined; + if (std.mem.eql(u8, token.string, "string")) + rc = Self{ .string = try std.json.innerParse([]const u8, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "number")) + rc = Self{ .number = try std.json.innerParse([]const u8, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "binary")) + rc = Self{ .binary = try std.json.innerParse([]const u8, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "boolean")) + rc = Self{ .boolean = try std.json.innerParse(bool, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "null")) + rc = Self{ .null = try std.json.innerParse(bool, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "string_set")) + rc = Self{ .string_set = try std.json.innerParse([][]const u8, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "number_set")) + rc = Self{ .number_set = try std.json.innerParse([][]const u8, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "binary_set")) + rc = Self{ .binary_set = try std.json.innerParse([][]const u8, allocator, source, options) }; + if (std.mem.eql(u8, token.string, "list")) { + var json = try std.json.Value.jsonParse(allocator, source, options); + rc = Self{ .list = json.array }; + } + if (std.mem.eql(u8, token.string, "map")) { + var json = try std.json.Value.jsonParse(allocator, source, options); + rc = Self{ .map = json.object }; + } + if (.object_end != try source.next()) return error.UnexpectedToken; + return rc; + } + + pub fn jsonStringify(self: Self, jws: anytype) !void { + try jws.beginObject(); + try jws.objectField(@tagName(self)); + switch (self) { + .string, .number, .binary => |s| try jws.write(s), + .boolean, .null => |b| try jws.write(b), + .string_set, .number_set, .binary_set => |s| try jws.write(s), + .list => |l| try jws.write(l.items), + .map => |inner| { + try jws.beginObject(); + var it = inner.iterator(); + while (it.next()) |entry| { + try jws.objectField(entry.key_ptr.*); + try jws.write(entry.value_ptr.*); + } + try jws.endObject(); + }, + } + return try jws.endObject(); + } pub fn validate(self: Self) !void { switch (self) { .string, .string_set, .boolean, .null, .map, .list => {}, @@ -475,19 +532,114 @@ const Params = struct { pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 { const allocator = request.allocator; const account_id = request.account_id; - _ = account_id; var params = try Params.parseRequest(allocator, request, writer); defer params.deinit(); + try params.validate(); + // 1. Get the list of encrypted table names using the account id root key - // 2. Get the matching table-scope encryption keys - // 3. For each table request: - // 1. Find the hash values of put and delete requests in the request - // 2. Encrypt the hash values - // 3. Delete any existing records with that hash value (for delete requests, we're done here) - // 4. If put request, put the new item in the table (with encrypted values, using table encryption) + var account_tables = try ddb.tablesForAccount(allocator, account_id); + defer account_tables.deinit(); + // 2. For each table request: + for (params.request_items) |table_req| { + var request_table: ddb.Table = undefined; + var found = false; + for (account_tables.items) |tbl| { + if (std.mem.eql(u8, tbl.name, table_req.table_name)) { + request_table = tbl; + found = true; + } + } + if (!found) { + std.log.warn("Table name in request does not exist in account. Table name specified: {s}", .{table_req.table_name}); + continue; // TODO: This API has the concept of returning the list of unprocessed stuff. We need to do that here + } + for (table_req.requests) |req| { + if (req.put_request) |p| + try process_request(allocator, account_tables.db, &request_table, .put, p); + if (req.delete_request) |d| + try process_request(allocator, account_tables.db, &request_table, .delete, d); + } + } // TODO: Capacity limiting and metrics - return "hi"; + if (params.return_consumed_capacity != .none or params.return_item_collection_metrics) + try returnException( + request, + .internal_server_error, + error.NotImplemented, + writer, + "Changes processed, but metrics/capacity are not yet implemented", + ); + // { + // "UnprocessedItems": { + // "Forum": [ + // { + // "PutRequest": { + // "Item": { + // "Name": { + // "S": "Amazon ElastiCache" + // }, + // "Category": { + // "S": "Amazon Web Services" + // } + // } + // } + // } + // ] + // }, + // "ConsumedCapacity": [ + // { + // "TableName": "Forum", + // "CapacityUnits": 3 + // } + // ] + // } + return "{}"; +} +const RequestType = enum { + put, + delete, +}; +fn process_request( + allocator: std.mem.Allocator, + db: anytype, + table: *ddb.Table, + req_type: RequestType, + req_attributes: []Attribute, +) !void { + _ = db; + // 1. Find the hash values of put and delete requests in the request + const hash_key_attribute_name = table.info.value.hash_key_attribute_name; + const range_key_attribute_name = table.info.value.range_key_attribute_name; + var hash_attribute: ?Attribute = null; + var range_attribute: ?Attribute = null; + for (req_attributes) |*att| { + if (std.mem.eql(u8, att.name, hash_key_attribute_name)) { + hash_attribute = att.*; + continue; + } + if (range_key_attribute_name) |r| { + if (std.mem.eql(u8, att.name, r)) + range_attribute = att.*; + } + } + if (hash_attribute == null) return error.HashAttributeNotFound; + if (range_attribute == null and range_key_attribute_name != null) return error.RangeAttributeNotFound; + + const hash_value = try std.json.stringifyAlloc(allocator, hash_attribute.?.value, .{}); + defer allocator.free(hash_value); + const range_value = if (range_attribute) |r| + try std.json.stringifyAlloc(allocator, r.value, .{}) + else + null; + defer if (range_value) |r| allocator.free(r); + if (req_type == .delete) { + try table.deleteItem(hash_value, range_value); + } else { + const attributes_as_string = try std.json.stringifyAlloc(allocator, req_attributes, .{}); + defer allocator.free(attributes_as_string); + try table.putItem(hash_value, range_value, attributes_as_string); + } } test "basic request parsing failure" { @@ -661,3 +813,248 @@ test "all types request parsing" { try std.testing.expectEqualStrings("this text is base64-encoded", buf); try std.testing.expect(put_and_or_delete.delete_request == null); } + +test "write item" { + Account.test_retain_db = true; + defer Account.testDbDeinit(); + const allocator = std.testing.allocator; + const account_id = "1234"; + var db = try Account.dbForAccount(allocator, account_id); + defer allocator.destroy(db); + defer Account.testDbDeinit(); + const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed + defer account.deinit(); + var hash = ddb.AttributeDefinition{ .name = "Artist", .type = .S }; + var range = ddb.AttributeDefinition{ .name = "SongTitle", .type = .S }; + var definitions = @constCast(&[_]*ddb.AttributeDefinition{ + &hash, + &range, + }); + var table_info: ddb.TableInfo = .{ + .table_key = undefined, + .attribute_definitions = definitions[0..], + .hash_key_attribute_name = "Artist", + .range_key_attribute_name = "SongTitle", + }; + encryption.randomEncodedKey(&table_info.table_key); + try ddb.createDdbTable( + allocator, + db, + account, + "MusicCollection", + table_info, + 5, + 5, + false, + ); + var request = AuthenticatedRequest{ + .output_format = .text, + .event_data = + \\ { + \\ "RequestItems": { + \\ "MusicCollection": [ + \\ { + \\ "PutRequest": { + \\ "Item": { + \\ "Artist": { + \\ "S": "Mettalica" + \\ }, + \\ "SongTitle": { + \\ "S": "Master of Puppets" + \\ }, + \\ "Binary": { + \\ "B": "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk" + \\ }, + \\ "Boolean": { + \\ "BOOL": true + \\ }, + \\ "Null": { + \\ "NULL": true + \\ }, + \\ "List": { + \\ "L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}] + \\ }, + \\ "Map": { + \\ "M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}} + \\ }, + \\ "Number Set": { + \\ "NS": ["42.2", "-19", "7.5", "3.14"] + \\ }, + \\ "Binary Set": { + \\ "BS": ["U3Vubnk=", "UmFpbnk=", "U25vd3k="] + \\ }, + \\ "String Set": { + \\ "SS": ["Giraffe", "Hippo" ,"Zebra"] + \\ } + \\ } + \\ } + \\ } + \\ ] + \\ } + \\ } + , + .headers = undefined, + .status = .ok, + .reason = "", + .account_id = "1234", + .allocator = allocator, + }; + var al = std.ArrayList(u8).init(allocator); + defer al.deinit(); + var writer = al.writer(); + _ = try handler(&request, writer); +} + +test "round trip attributes" { + const allocator = std.testing.allocator; + var json_stuff = try std.json.parseFromSlice(std.json.Value, allocator, + \\ { + \\ "M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}, + \\ "L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}] + \\ } + , .{}); + defer json_stuff.deinit(); + const map = json_stuff.value.object.get("M").?.object; + const list = json_stuff.value.object.get("L").?.array; + const attributes = &[_]Attribute{ + .{ + .name = "foo", + .value = .{ .string = "bar" }, + }, + .{ + .name = "foo", + .value = .{ .number = "42" }, + }, + .{ + .name = "foo", + .value = .{ .binary = "YmFy" }, // "bar" + }, + .{ + .name = "foo", + .value = .{ .boolean = true }, + }, + .{ + .name = "foo", + .value = .{ .null = false }, + }, + .{ + .name = "foo", + .value = .{ .string_set = @constCast(&[_][]const u8{ "foo", "bar" }) }, + }, + .{ + .name = "foo", + .value = .{ .number_set = @constCast(&[_][]const u8{ "41", "42" }) }, + }, + .{ + .name = "foo", + .value = .{ .binary_set = @constCast(&[_][]const u8{ "Zm9v", "YmFy" }) }, // foo, bar + }, + .{ + .name = "foo", + .value = .{ .map = map }, + }, + .{ + .name = "foo", + .value = .{ .list = list }, + }, + }; + const attributes_as_string = try std.json.stringifyAlloc( + allocator, + attributes, + .{ .whitespace = .indent_2 }, + ); + defer allocator.free(attributes_as_string); + try std.testing.expectEqualStrings( + \\[ + \\ { + \\ "name": "foo", + \\ "value": { + \\ "string": "bar" + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "number": "42" + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "binary": "YmFy" + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "boolean": true + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "null": false + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "string_set": [ + \\ "foo", + \\ "bar" + \\ ] + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "number_set": [ + \\ "41", + \\ "42" + \\ ] + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "binary_set": [ + \\ "Zm9v", + \\ "YmFy" + \\ ] + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "map": { + \\ "Name": { + \\ "S": "Joe" + \\ }, + \\ "Age": { + \\ "N": "35" + \\ } + \\ } + \\ } + \\ }, + \\ { + \\ "name": "foo", + \\ "value": { + \\ "list": [ + \\ { + \\ "S": "Cookies" + \\ }, + \\ { + \\ "S": "Coffee" + \\ }, + \\ { + \\ "N": "3.14159" + \\ } + \\ ] + \\ } + \\ } + \\] + , attributes_as_string); + + var round_tripped = try std.json.parseFromSlice([]Attribute, allocator, attributes_as_string, .{}); + defer round_tripped.deinit(); +} diff --git a/src/createtable.zig b/src/createtable.zig index 3886560..6cda337 100644 --- a/src/createtable.zig +++ b/src/createtable.zig @@ -62,12 +62,14 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 { allocator.free(request_params.table_info.attribute_definitions); } var db = try Account.dbForAccount(allocator, account_id); + defer allocator.destroy(db); + defer db.deinit(); const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed defer account.deinit(); try ddb.createDdbTable( allocator, - &db, + db, account, request_params.table_name, request_params.table_info, diff --git a/src/ddb.zig b/src/ddb.zig index cbdb9ff..db8888b 100644 --- a/src/ddb.zig +++ b/src/ddb.zig @@ -5,6 +5,17 @@ const Account = @import("Account.zig"); const encryption = @import("encryption.zig"); const builtin = @import("builtin"); +// We need our enryption to be able to store/retrieve and otherwise work like +// a database. So the use of a nonce here defeats these use cases +const nonce = &[_]u8{ + 0x55, 0x4a, 0x38, 0x16, 0x55, 0x55, 0x2d, 0x05, + 0x32, 0x70, 0x3f, 0xa0, 0xde, 0x3d, 0x2c, 0xb8, + 0x89, 0x40, 0x07, 0xc5, 0x57, 0x7d, 0xa0, 0xb8, +}; + +fn encryptAndEncode(allocator: std.mem.Allocator, key: [encryption.key_length]u8, plaintext: []const u8) ![]const u8 { + return try encryption.encryptAndEncodeWithNonce(allocator, key, nonce.*, plaintext); +} /// Serialized into metadata table. This is an explicit enum with a twin /// AttributeTypeName enum to make coding with these types easier. Use /// Descriptor for storage or communication with the outside world, and @@ -60,75 +71,193 @@ pub const TableInfo = struct { range_key_attribute_name: ?[]const u8, }; -pub const TableArray = struct { +pub const AccountTables = struct { items: []Table, + db: *sqlite.Db, allocator: std.mem.Allocator, - pub fn init(allocator: std.mem.Allocator, length: usize) !TableArray { + pub fn init(allocator: std.mem.Allocator, length: usize, db: *sqlite.Db) !AccountTables { return .{ .allocator = allocator, .items = try allocator.alloc(Table, length), + .db = db, }; } - pub fn deinit(self: *TableArray) void { + pub fn deinit(self: *AccountTables) void { for (self.items) |*item| item.deinit(); self.allocator.free(self.items); } }; pub const Table = struct { - table_name: []const u8, - table_key: [encryption.key_length]u8, + name: []const u8, + key: [encryption.key_length]u8, + info: std.json.Parsed(TableInfo), + /// underlying data for json parsed version + info_str: []const u8, + db: *sqlite.Db, allocator: std.mem.Allocator, + encrypted_name: []const u8, + + pub fn deleteItem(self: *Table, hash_value: []const u8, range_value: ?[]const u8) !void { + const encrypted_hash = try encryptAndEncode(self.allocator, self.key, hash_value); + defer self.allocator.free(encrypted_hash); + const encrypted_range = if (range_value) |r| + try encryptAndEncode(self.allocator, self.key, r) + else + null; + defer if (encrypted_range != null) self.allocator.free(encrypted_range.?); + + // TODO: hashKey and rangeKey are text, while hashvalue/rangevalue are blobx + // this is to accomodate non-string hash/range value by running a hash + // function over the data, probably base64 encoded. Do we want to do + // something like this? + const delete = try std.fmt.allocPrint(self.allocator, + \\DELETE FROM '{s}' WHERE + \\ hashKey = ? + \\ AND + \\ rangeKey = ? + , .{self.encrypted_name}); + defer self.allocator.free(delete); + try self.db.execDynamic(delete, .{}, .{ + encrypted_hash, + encrypted_range, + }); + } + pub fn putItem( + self: *Table, + hash_value: []const u8, + range_value: ?[]const u8, + data: []const u8, + ) !void { + var sp = try self.db.savepoint("putItem"); + + errdefer sp.rollback(); + // TODO: savepoint this + try self.deleteItem(hash_value, range_value); + const encrypted_hash = try encryptAndEncode(self.allocator, self.key, hash_value); + defer self.allocator.free(encrypted_hash); + const encrypted_range = if (range_value) |r| + try encryptAndEncode(self.allocator, self.key, r) + else + null; + defer if (encrypted_range != null) self.allocator.free(encrypted_range.?); + const encrypted_data = try encryptAndEncode(self.allocator, self.key, data); + defer self.allocator.free(encrypted_data); + + // TODO: hashKey and rangeKey are text, while hashvalue/rangevalue are blobx + // this is to accomodate non-string hash/range value by running a hash + // function over the data, probably base64 encoded. Do we want to do + // something like this? + const insert = try std.fmt.allocPrint(self.allocator, + \\INSERT INTO '{s}' ( + \\ hashKey, + \\ rangeKey, + \\ hashValue, + \\ rangeValue, + \\ itemSize, + \\ ObjectJSON + \\ ) VALUES ( ?, ?, ?, ?, ?, ? ) + // This syntax doesn't seem to work here? Used ?'s above + // \\ $encrypted_hash{{[]const u8}}, + // \\ $encrypted_range{{[]const u8}}, + // \\ $encrypted_hash{{[]const u8}}, + // \\ $encrypted_range{{[]const u8}}, + // \\ $len{{usize}}, + // \\ $encrypted_data{{[]const u8}} + // \\ ) + , .{self.encrypted_name}); + defer self.allocator.free(insert); + var diags = sqlite.Diagnostics{}; + // std.debug.print( + // \\==================== + // \\Insert to table: {s} + // \\ hashKey: {s} + // \\ rangeKey: {?s} + // \\ hashValue: {s} + // \\ rangeValue: {?s} + // \\ itemSize: {d} + // \\ ObjectJSON: {s} + // \\==================== + // , .{ + // self.encrypted_name, + // encrypted_hash, + // encrypted_range, + // encrypted_hash, + // encrypted_range, + // encrypted_data.len, + // encrypted_data, + // }); + self.db.execDynamic(insert, .{ .diags = &diags }, .{ + encrypted_hash, + encrypted_range, + encrypted_hash, + encrypted_range, + encrypted_data.len, + encrypted_data, + }) catch |e| { + std.debug.print("Insert stmt: {s}\n", .{insert}); + std.debug.print("SqlLite diags: {s}\n", .{diags}); + return e; + }; + sp.commit(); + } pub fn deinit(self: *Table) void { - std.crypto.utils.secureZero(u8, &self.table_key); - self.allocator.free(self.table_name); + std.crypto.utils.secureZero(u8, &self.key); + self.allocator.free(self.encrypted_name); + self.allocator.free(self.info_str); + self.info.deinit(); + self.allocator.free(self.name); } }; -// Gets all table names/keys for the account. Caller owns returned array -pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !TableArray { +/// Gets all table names/keys for the account. Caller owns returned array +/// The return value will also provide the opened database. As encryption keys +/// are stored in here, realistically, this will be the first function called +/// every time anything interacts with the database, so this function opens +/// the database for you +pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !AccountTables { + + // TODO: This function should take a list of table names, which can then be used + // to filter the query below rather than just grabbing everything var db = try Account.dbForAccount(allocator, account_id); - defer if (!builtin.is_test) db.deinit(); + errdefer if (!builtin.is_test) db.deinit(); const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed defer account.deinit(); const query = - \\SELECT TableName as table_name, TableInfo as table_info FROM dm + \\SELECT TableName as name, TableInfo as info FROM dm ; var stmt = try db.prepare(query); defer stmt.deinit(); const rows = try stmt.all(struct { - table_name: []const u8, - table_info: []const u8, + name: []const u8, + info: []const u8, }, allocator, .{}, .{}); defer allocator.free(rows); - var rc = try TableArray.init(allocator, rows.len); + var rc = try AccountTables.init(allocator, rows.len, db); errdefer rc.deinit(); // std.debug.print(" \n===\nRow count: {d}\n===\n", .{rows.len}); for (rows, 0..) |row, inx| { - defer allocator.free(row.table_name); - defer allocator.free(row.table_info); + defer allocator.free(row.name); + defer allocator.free(row.info); const table_name = try encryption.decodeAndDecrypt( allocator, account.root_account_key.*, - row.table_name, + row.name, ); errdefer allocator.free(table_name); const table_info_str = try encryption.decodeAndDecrypt( allocator, account.root_account_key.*, - row.table_info, + row.info, ); - defer allocator.free(table_info_str); - // std.debug.print(" \n===TableInfo: {s}\n===\n", .{table_info_str}); - const table_info = try std.json.parseFromSlice(TableInfo, allocator, table_info_str, .{}); - defer table_info.deinit(); + // errdefer allocator.free(table_info.table_key); // defer { // // we don't even really need to defer this... @@ -141,10 +270,14 @@ pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !T rc.items[inx] = .{ .allocator = allocator, - .table_name = table_name, - .table_key = undefined, + .name = table_name, + .encrypted_name = try allocator.dupe(u8, row.name), + .key = undefined, + .info = try std.json.parseFromSlice(TableInfo, allocator, table_info_str, .{}), + .info_str = table_info_str, + .db = db, }; - try encryption.decodeKey(&rc.items[inx].table_key, table_info.value.table_key); + try encryption.decodeKey(&rc.items[inx].key, rc.items[inx].info.value.table_key); } return rc; } @@ -220,12 +353,12 @@ fn insertIntoDatabaseMetadata( billing_mode_pay_per_request: bool, ) ![]const u8 { // TODO: better to do all encryption when request params are parsed? - const encrypted_table_name = try encryption.encryptAndEncode(allocator, account.root_account_key.*, table_name); + const encrypted_table_name = try encryptAndEncode(allocator, account.root_account_key.*, table_name); errdefer allocator.free(encrypted_table_name); // We'll json serialize our table_info structure, encrypt, encode, and plow in const table_info_string = try std.json.stringifyAlloc(allocator, table_info, .{ .whitespace = .indent_2 }); defer allocator.free(table_info_string); - const encrypted_table_info = try encryption.encryptAndEncode(allocator, account.root_account_key.*, table_info_string); + const encrypted_table_info = try encryptAndEncode(allocator, account.root_account_key.*, table_info_string); defer allocator.free(encrypted_table_info); try insertIntoDm(db, encrypted_table_name, encrypted_table_info, read_capacity_units, write_capacity_units, billing_mode_pay_per_request); return encrypted_table_name; @@ -279,7 +412,7 @@ fn insertIntoDm( }); } -fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !sqlite.Db { +fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db { var db = try Account.dbForAccount(allocator, account_id); const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed defer account.deinit(); @@ -298,7 +431,7 @@ fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !sqlite encryption.randomEncodedKey(&table_info.table_key); try createDdbTable( allocator, - &db, + db, account, "MusicCollection", table_info, @@ -312,18 +445,38 @@ test "can create a table" { const allocator = std.testing.allocator; const account_id = "1234"; var db = try testCreateTable(allocator, account_id); + defer allocator.destroy(db); defer db.deinit(); } test "can list tables in an account" { Account.test_retain_db = true; - defer Account.test_retain_db = false; const allocator = std.testing.allocator; const account_id = "1234"; var db = try testCreateTable(allocator, account_id); - defer db.deinit(); + defer allocator.destroy(db); + defer Account.testDbDeinit(); var table_list = try tablesForAccount(allocator, account_id); defer table_list.deinit(); try std.testing.expectEqual(@as(usize, 1), table_list.items.len); - try std.testing.expectEqualStrings("MusicCollection", table_list.items[0].table_name); + try std.testing.expectEqualStrings("MusicCollection", table_list.items[0].name); + // std.debug.print(" \n===\nKey: {s}\n===\n", .{std.fmt.fmtSliceHexLower(&table_list.items[0].table_key)}); +} + +test "can put an item in a table in an account" { + Account.test_retain_db = true; + const allocator = std.testing.allocator; + const account_id = "1234"; + var db = try testCreateTable(allocator, account_id); + defer allocator.destroy(db); + defer Account.testDbDeinit(); + var table_list = try tablesForAccount(allocator, account_id); + defer table_list.deinit(); + try std.testing.expectEqualStrings("MusicCollection", table_list.items[0].name); + var table = table_list.items[0]; + try table.putItem("Foo Fighters", "Everlong", "whatevs"); + // This should succeed, because putItem is an upsert mechanism + try table.putItem("Foo Fighters", "Everlong", "whatevs"); + + // TODO: this test should do getItem to verify data // std.debug.print(" \n===\nKey: {s}\n===\n", .{std.fmt.fmtSliceHexLower(&table_list.items[0].table_key)}); } diff --git a/src/encryption.zig b/src/encryption.zig index 77216b4..3dc1e2f 100644 --- a/src/encryption.zig +++ b/src/encryption.zig @@ -5,6 +5,7 @@ pub const salt_length = 256 / 8; // https://crypto.stackexchange.com/a/56132 pub const encoded_salt_length = std.base64.standard.Encoder.calcSize(salt_length); pub const key_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.key_length; pub const encoded_key_length = std.base64.standard.Encoder.calcSize(key_length); +pub const nonce_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length; /// Generates a random salt of appropriate length pub fn randomSalt(salt: *[salt_length]u8) void { @@ -56,19 +57,25 @@ pub fn deriveKeyFromEncodedSalt(derived_key: *[key_length]u8, password: []const return derived_key.*; } -/// Encrypts data. Use deriveKey function to get a key from password/salt +/// Encrypts data. Use deriveKey function to get a key from password/salt. +/// Uses a random nonce. To supply a nonce instead, use encryptWithNonce /// Caller owns memory pub fn encrypt(allocator: std.mem.Allocator, key: [key_length]u8, plaintext: []const u8) ![]const u8 { + var nonce: [std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length]u8 = undefined; + std.crypto.random.bytes(&nonce); // add nonce to beginning of our ciphertext + return try encryptWithNonce(allocator, key, nonce, plaintext); +} + +pub fn encryptWithNonce(allocator: std.mem.Allocator, key: [key_length]u8, nonce: [nonce_length]u8, plaintext: []const u8) ![]const u8 { var ciphertext = try allocator.alloc( u8, - std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length + plaintext.len, + nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length + plaintext.len, ); errdefer allocator.free(ciphertext); // Create the nonce - const nonce_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length; - std.crypto.random.bytes(ciphertext[0..nonce_length]); // add nonce to beginning of our ciphertext - const nonce = ciphertext[0..nonce_length]; + @memcpy(ciphertext[0..nonce_length], nonce[0..]); // add nonce to beginning of our ciphertext + const nonce_copy = ciphertext[0..nonce_length]; const tag = ciphertext[nonce_length .. nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length]; const c = ciphertext[nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length ..]; @@ -78,7 +85,7 @@ pub fn encrypt(allocator: std.mem.Allocator, key: [key_length]u8, plaintext: []c tag, plaintext, "ad", - nonce.*, + nonce_copy.*, key, ); return ciphertext; @@ -95,14 +102,24 @@ pub fn encryptAndEncode(allocator: std.mem.Allocator, key: [key_length]u8, plain return Encoder.encode(encoded_ciphertext, ciphertext); } +/// Encrypts data. Use deriveKey function to get a key from password/salt +/// Caller owns memory +pub fn encryptAndEncodeWithNonce(allocator: std.mem.Allocator, key: [key_length]u8, nonce: [nonce_length]u8, plaintext: []const u8) ![]const u8 { + const ciphertext = try encryptWithNonce(allocator, key, nonce, plaintext); + defer allocator.free(ciphertext); + const Encoder = std.base64.standard.Encoder; + var encoded_ciphertext = try allocator.alloc(u8, Encoder.calcSize(ciphertext.len)); + errdefer allocator.free(encoded_ciphertext); + return Encoder.encode(encoded_ciphertext, ciphertext); +} + /// Decrypts data. Use deriveKey function to get a key from password/salt pub fn decrypt(allocator: std.mem.Allocator, key: [key_length]u8, ciphertext: []const u8) ![]const u8 { var plaintext = try allocator.alloc( u8, - ciphertext.len - std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length - std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length, + ciphertext.len - nonce_length - std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length, ); errdefer allocator.free(plaintext); - const nonce_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length; const nonce = ciphertext[0..nonce_length].*; const tag = ciphertext[nonce_length .. nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length].*; const c = ciphertext[nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length ..]; @@ -164,6 +181,28 @@ test "can encrypt and decrypt data with simpler api but without KDF" { try std.testing.expectEqualStrings(plaintext, decrypted_text[0..]); } +test "can encrypt twice with same result" { + const allocator = std.testing.allocator; + const plaintext = "Hello, Zig!"; + var key: [key_length]u8 = undefined; + var nonce: [nonce_length]u8 = undefined; + var encoded_key: [encoded_key_length]u8 = undefined; + randomEncodedKey(encoded_key[0..]); + std.crypto.random.bytes(&nonce); + // std.testing.log_level = .debug; + std.log.debug("Encoded key: {s}", .{encoded_key}); + + try decodeKey(&key, encoded_key); + const ciphertext = try encryptWithNonce(allocator, key, nonce, plaintext); + defer allocator.free(ciphertext); + std.log.debug("Ciphertext: {s}\n", .{std.fmt.fmtSliceHexLower(ciphertext)}); + + const ciphertext2 = try encryptWithNonce(allocator, key, nonce, plaintext); + defer allocator.free(ciphertext2); + std.log.debug("Ciphertext: {s}\n", .{std.fmt.fmtSliceHexLower(ciphertext2)}); + try std.testing.expectEqualSlices(u8, ciphertext, ciphertext2); +} + // test "can encrypt and decrypt data" { // var tag: [std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length]u8 = undefined; // const password = "mySecurePassword";